FEATURE (cloud): Add cloud billing (Paddle integration, per-GB storage pricing)

This commit is contained in:
Rostislav Dugin
2026-03-26 12:35:32 +03:00
parent c648e9c29f
commit 61a0bcabb1
106 changed files with 8924 additions and 1963 deletions

View File

@@ -27,6 +27,13 @@ VALKEY_PORT=6379
VALKEY_USERNAME=
VALKEY_PASSWORD=
VALKEY_IS_SSL=false
# billing
PRICE_PER_GB_CENTS=
IS_PADDLE_SANDBOX=true
PADDLE_API_KEY=
PADDLE_WEBHOOK_SECRET=
PADDLE_PRICE_ID=
PADDLE_CLIENT_TOKEN=
# testing
# To obtain the Google Drive test variables: add a Google Drive storage in the UI, then copy its values from that storage into the fields below
TEST_GOOGLE_DRIVE_CLIENT_ID=

View File

@@ -9,6 +9,7 @@ import (
"os/exec"
"os/signal"
"path/filepath"
"runtime/debug"
"syscall"
"time"
@@ -25,6 +26,8 @@ import (
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/billing"
billing_paddle "databasus-backend/internal/features/billing/paddle"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
"databasus-backend/internal/features/encryption/secrets"
@@ -105,7 +108,9 @@ func main() {
go generateSwaggerDocs(log)
gin.SetMode(gin.ReleaseMode)
ginApp := gin.Default()
ginApp := gin.New()
ginApp.Use(gin.Logger())
ginApp.Use(ginRecoveryWithLogger(log))
// Add GZIP compression middleware
ginApp.Use(gzip.Gzip(
@@ -217,6 +222,10 @@ func setUpRoutes(r *gin.Engine) {
backups_controllers.GetPostgresWalBackupController().RegisterRoutes(v1)
databases.GetDatabaseController().RegisterPublicRoutes(v1)
if config.GetEnv().IsCloud {
billing_paddle.GetPaddleBillingController().RegisterPublicRoutes(v1)
}
// Setup auth middleware
userService := users_services.GetUserService()
authMiddleware := users_middleware.AuthMiddleware(userService)
@@ -240,6 +249,7 @@ func setUpRoutes(r *gin.Engine) {
audit_logs.GetAuditLogController().RegisterRoutes(protected)
users_controllers.GetManagementController().RegisterRoutes(protected)
users_controllers.GetSettingsController().RegisterRoutes(protected)
billing.GetBillingController().RegisterRoutes(protected)
}
func setUpDependencies() {
@@ -252,6 +262,11 @@ func setUpDependencies() {
storages.SetupDependencies()
backups_config.SetupDependencies()
task_cancellation.SetupDependencies()
billing.SetupDependencies()
if config.GetEnv().IsCloud {
billing_paddle.SetupDependencies()
}
}
func runBackgroundTasks(log *slog.Logger) {
@@ -308,6 +323,12 @@ func runBackgroundTasks(log *slog.Logger) {
go runWithPanicLogging(log, "restore nodes registry background service", func() {
restoring.GetRestoreNodesRegistry().Run(ctx)
})
if config.GetEnv().IsCloud {
go runWithPanicLogging(log, "billing background service", func() {
billing.GetBillingService().Run(ctx, *log)
})
}
} else {
log.Info("Skipping primary node tasks as not primary node")
}
@@ -330,7 +351,7 @@ func runBackgroundTasks(log *slog.Logger) {
func runWithPanicLogging(log *slog.Logger, serviceName string, fn func()) {
defer func() {
if r := recover(); r != nil {
log.Error("Panic in "+serviceName, "error", r)
log.Error("Panic in "+serviceName, "error", r, "stacktrace", string(debug.Stack()))
}
}()
fn()
@@ -410,6 +431,25 @@ func enableCors(ginApp *gin.Engine) {
}
}
// ginRecoveryWithLogger builds a Gin middleware that catches panics raised by
// downstream handlers, records them through the provided structured logger
// (error value, stack trace, HTTP method, and request path), and aborts the
// request with a 500 Internal Server Error.
func ginRecoveryWithLogger(log *slog.Logger) gin.HandlerFunc {
	return func(ctx *gin.Context) {
		defer func() {
			rec := recover()
			if rec == nil {
				return
			}
			// Log with full context before failing the request; the stack
			// trace is captured here, inside the deferred recover.
			log.Error("Panic recovered in HTTP handler",
				"error", rec,
				"stacktrace", string(debug.Stack()),
				"method", ctx.Request.Method,
				"path", ctx.Request.URL.Path,
			)
			ctx.AbortWithStatus(http.StatusInternalServerError)
		}()
		ctx.Next()
	}
}
func mountFrontend(ginApp *gin.Engine) {
staticDir := "./ui/build"
ginApp.NoRoute(func(c *gin.Context) {

View File

@@ -5,6 +5,7 @@ go 1.26.1
require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3
github.com/PaddleHQ/paddle-go-sdk v1.0.0
github.com/gin-contrib/cors v1.7.5
github.com/gin-contrib/gzip v1.2.3
github.com/gin-gonic/gin v1.10.0
@@ -100,6 +101,8 @@ require (
github.com/emersion/go-message v0.18.2 // indirect
github.com/emersion/go-vcard v0.0.0-20241024213814-c9703dde27ff // indirect
github.com/flynn/noise v1.1.0 // indirect
github.com/ggicci/httpin v0.19.0 // indirect
github.com/ggicci/owl v0.8.2 // indirect
github.com/go-chi/chi/v5 v5.2.3 // indirect
github.com/go-darwin/apfs v0.0.0-20211011131704-f84b94dbf348 // indirect
github.com/go-git/go-billy/v5 v5.6.2 // indirect

View File

@@ -77,6 +77,8 @@ github.com/Max-Sum/base32768 v0.0.0-20230304063302-18e6ce5945fd/go.mod h1:C8yoIf
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/PaddleHQ/paddle-go-sdk v1.0.0 h1:+EXitsPFbRcc0CpQE/MIeudxiVOR8pFe/aOWTEUHDKU=
github.com/PaddleHQ/paddle-go-sdk v1.0.0/go.mod h1:kbBBzf0BHEj38QvhtoELqlGip3alKgA/I+vl7RQzB58=
github.com/ProtonMail/bcrypt v0.0.0-20210511135022-227b4adcab57/go.mod h1:HecWFHognK8GfRDGnFQbW/LiV7A3MX3gZVs45vk5h8I=
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs69zUkSzubzjBbL+cmOXgnmt9Fyd9ug=
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo=
@@ -248,6 +250,10 @@ github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9t
github.com/geoffgarside/ber v1.1.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc=
github.com/geoffgarside/ber v1.2.0 h1:/loowoRcs/MWLYmGX9QtIAbA+V/FrnVLsMMPhwiRm64=
github.com/geoffgarside/ber v1.2.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc=
github.com/ggicci/httpin v0.19.0 h1:p0B3SWLVgg770VirYiHB14M5wdRx3zR8mCTzM/TkTQ8=
github.com/ggicci/httpin v0.19.0/go.mod h1:hzsQHcbqLabmGOycf7WNw6AAzcVbsMeoOp46bWAbIWc=
github.com/ggicci/owl v0.8.2 h1:og+lhqpzSMPDdEB+NJfzoAJARP7qCG3f8uUC3xvGukA=
github.com/ggicci/owl v0.8.2/go.mod h1:PHRD57u41vFN5UtFz2SF79yTVoM3HlWpjMiE+ZU2dj4=
github.com/gin-contrib/cors v1.7.5 h1:cXC9SmofOrRg0w9PigwGlHG3ztswH6bqq4vJVXnvYMk=
github.com/gin-contrib/cors v1.7.5/go.mod h1:4q3yi7xBEDDWKapjT2o1V7mScKDDr8k+jZ0fSquGoy0=
github.com/gin-contrib/gzip v1.2.3 h1:dAhT722RuEG330ce2agAs75z7yB+NKvX/ZM1r8w0u2U=
@@ -454,6 +460,8 @@ github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
github.com/jtolio/noiseconn v0.0.0-20231127013910-f6d9ecbf1de7 h1:JcltaO1HXM5S2KYOYcKgAV7slU0xPy1OcvrVgn98sRQ=
github.com/jtolio/noiseconn v0.0.0-20231127013910-f6d9ecbf1de7/go.mod h1:MEkhEPFwP3yudWO0lj6vfYpLIB+3eIcuIW+e0AZzUQk=
github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo=
github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA=
github.com/jzelinskie/whirlpool v0.0.0-20201016144138-0675e54bb004 h1:G+9t9cEtnC9jFiTxyptEKuNIAbiN5ZCQzX2a74lj3xg=
github.com/jzelinskie/whirlpool v0.0.0-20201016144138-0675e54bb004/go.mod h1:KmHnJWQrgEvbuy0vcvj00gtMqbvNn1L+3YUZLK/B92c=
github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU=

View File

@@ -5,6 +5,7 @@ import (
"path/filepath"
"strings"
"sync"
"time"
"github.com/ilyakaznacheev/cleanenv"
"github.com/joho/godotenv"
@@ -53,6 +54,20 @@ type EnvVariables struct {
TempFolder string
SecretKeyPath string
// Billing (always tax-exclusive)
PricePerGBCents int64 `env:"PRICE_PER_GB_CENTS"`
MinStorageGB int
MaxStorageGB int
TrialDuration time.Duration
TrialStorageGB int
GracePeriod time.Duration
// Paddle billing
IsPaddleSandbox bool `env:"IS_PADDLE_SANDBOX"`
PaddleApiKey string `env:"PADDLE_API_KEY"`
PaddleWebhookSecret string `env:"PADDLE_WEBHOOK_SECRET"`
PaddlePriceID string `env:"PADDLE_PRICE_ID"`
PaddleClientToken string `env:"PADDLE_CLIENT_TOKEN"`
TestGoogleDriveClientID string `env:"TEST_GOOGLE_DRIVE_CLIENT_ID"`
TestGoogleDriveClientSecret string `env:"TEST_GOOGLE_DRIVE_CLIENT_SECRET"`
TestGoogleDriveTokenJSON string `env:"TEST_GOOGLE_DRIVE_TOKEN_JSON"`
@@ -132,9 +147,9 @@ var (
once sync.Once
)
func GetEnv() EnvVariables {
func GetEnv() *EnvVariables {
once.Do(loadEnvVariables)
return env
return &env
}
func loadEnvVariables() {
@@ -363,5 +378,39 @@ func loadEnvVariables() {
}
// Billing
if env.IsCloud {
if env.PricePerGBCents == 0 {
log.Error("PRICE_PER_GB_CENTS is empty or zero")
os.Exit(1)
}
if env.PaddleApiKey == "" {
log.Error("PADDLE_API_KEY is empty")
os.Exit(1)
}
if env.PaddleWebhookSecret == "" {
log.Error("PADDLE_WEBHOOK_SECRET is empty")
os.Exit(1)
}
if env.PaddlePriceID == "" {
log.Error("PADDLE_PRICE_ID is empty")
os.Exit(1)
}
if env.PaddleClientToken == "" {
log.Error("PADDLE_CLIENT_TOKEN is empty")
os.Exit(1)
}
}
env.MinStorageGB = 20
env.MaxStorageGB = 10_000
env.TrialDuration = 24 * time.Hour
env.TrialStorageGB = 20
env.GracePeriod = 30 * 24 * time.Hour
log.Info("Environment variables loaded successfully!")
}

View File

@@ -171,26 +171,6 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
backup.BackupSizeMb = completedMBs
backup.BackupDurationMs = time.Since(start).Milliseconds()
// Check size limit (0 = unlimited)
if backupConfig.MaxBackupSizeMB > 0 &&
completedMBs > float64(backupConfig.MaxBackupSizeMB) {
errMsg := fmt.Sprintf(
"backup size (%.2f MB) exceeded maximum allowed size (%d MB)",
completedMBs,
backupConfig.MaxBackupSizeMB,
)
backup.Status = backups_core.BackupStatusFailed
backup.IsSkipRetry = true
backup.FailMessage = &errMsg
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to save backup with size exceeded error", "error", err)
}
cancel() // Cancel the backup context
return
}
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to update backup progress", "error", err)
}

View File

@@ -153,121 +153,3 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
assert.Equal(t, notifier.ID, capturedNotifier.ID)
})
}
// Test_BackupSizeLimits verifies how BackuperNode.MakeBackup enforces the
// per-backup size limit configured via BackupConfig.MaxBackupSizeMB:
// 0 means unlimited, while a positive value marks the backup as failed
// (with IsSkipRetry set) once the produced size exceeds it.
func Test_BackupSizeLimits(t *testing.T) {
cache_utils.ClearAllCache()
// Build the fixture chain: user -> workspace -> storage/notifier -> database.
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
t.Run("UnlimitedSize_MaxBackupSizeMBIsZero_BackupCompletes", func(t *testing.T) {
// Enable backups with unlimited size (0)
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 0 // unlimited
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
// Stub use case that yields a 10000 MB backup; with limit 0 it must still complete.
backuperNode.createBackupUseCase = &CreateLargeBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup completed successfully even with large size
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
assert.Equal(t, float64(10000), updatedBackup.BackupSizeMb)
assert.Nil(t, updatedBackup.FailMessage)
})
t.Run("SizeExceeded_BackupFailedWithIsSkipRetry", func(t *testing.T) {
// Enable backups with 5 MB limit
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 5
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
// Stub use case whose reported size grows past the 5 MB limit (to ~10 MB,
// per the fail-message assertions below), triggering the size check.
backuperNode.createBackupUseCase = &CreateProgressiveBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup was marked as failed with IsSkipRetry=true
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusFailed, updatedBackup.Status)
assert.True(t, updatedBackup.IsSkipRetry)
assert.NotNil(t, updatedBackup.FailMessage)
assert.Contains(t, *updatedBackup.FailMessage, "exceeded maximum allowed size")
assert.Contains(t, *updatedBackup.FailMessage, "10.00 MB")
assert.Contains(t, *updatedBackup.FailMessage, "5 MB")
assert.Greater(t, updatedBackup.BackupSizeMb, float64(5))
})
t.Run("SizeWithinLimit_BackupCompletes", func(t *testing.T) {
// Enable backups with 100 MB limit
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 100
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
// Stub use case that yields a 50 MB backup — within the 100 MB limit.
backuperNode.createBackupUseCase = &CreateMediumBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup completed successfully
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
assert.Equal(t, float64(50), updatedBackup.BackupSizeMb)
assert.Nil(t, updatedBackup.FailMessage)
})
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/google/uuid"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
@@ -26,6 +27,7 @@ type BackupCleaner struct {
backupRepository *backups_core.BackupRepository
storageService *storages.StorageService
backupConfigService *backups_config.BackupConfigService
billingService BillingService
fieldEncryptor util_encryption.FieldEncryptor
logger *slog.Logger
backupRemoveListeners []backups_core.BackupRemoveListener
@@ -44,6 +46,10 @@ func (c *BackupCleaner) Run(ctx context.Context) {
return
}
retentionLog := c.logger.With("task_name", "clean_by_retention_policy")
exceededLog := c.logger.With("task_name", "clean_exceeded_storage_backups")
staleLog := c.logger.With("task_name", "clean_stale_basebackups")
ticker := time.NewTicker(cleanerTickerInterval)
defer ticker.Stop()
@@ -52,16 +58,16 @@ func (c *BackupCleaner) Run(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
if err := c.cleanByRetentionPolicy(); err != nil {
c.logger.Error("Failed to clean backups by retention policy", "error", err)
if err := c.cleanByRetentionPolicy(retentionLog); err != nil {
retentionLog.Error("failed to clean backups by retention policy", "error", err)
}
if err := c.cleanExceededBackups(); err != nil {
c.logger.Error("Failed to clean exceeded backups", "error", err)
if err := c.cleanExceededStorageBackups(exceededLog); err != nil {
exceededLog.Error("failed to clean exceeded backups", "error", err)
}
if err := c.cleanStaleUploadedBasebackups(); err != nil {
c.logger.Error("Failed to clean stale uploaded basebackups", "error", err)
if err := c.cleanStaleUploadedBasebackups(staleLog); err != nil {
staleLog.Error("failed to clean stale uploaded basebackups", "error", err)
}
}
}
@@ -104,7 +110,7 @@ func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemo
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
}
func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
func (c *BackupCleaner) cleanStaleUploadedBasebackups(logger *slog.Logger) error {
staleBackups, err := c.backupRepository.FindStaleUploadedBasebackups(
time.Now().UTC().Add(-10 * time.Minute),
)
@@ -113,31 +119,30 @@ func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
}
for _, backup := range staleBackups {
backupLog := logger.With("database_id", backup.DatabaseID, "backup_id", backup.ID)
staleStorage, storageErr := c.storageService.GetStorageByID(backup.StorageID)
if storageErr != nil {
c.logger.Error(
"Failed to get storage for stale basebackup cleanup",
"backupId", backup.ID,
"storageId", backup.StorageID,
backupLog.Error(
"failed to get storage for stale basebackup cleanup",
"storage_id", backup.StorageID,
"error", storageErr,
)
} else {
if err := staleStorage.DeleteFile(c.fieldEncryptor, backup.FileName); err != nil {
c.logger.Error(
"Failed to delete stale basebackup file",
"backupId", backup.ID,
"fileName", backup.FileName,
"error", err,
backupLog.Error(
fmt.Sprintf("failed to delete stale basebackup file: %s", backup.FileName),
"error",
err,
)
}
metadataFileName := backup.FileName + ".metadata"
if err := staleStorage.DeleteFile(c.fieldEncryptor, metadataFileName); err != nil {
c.logger.Error(
"Failed to delete stale basebackup metadata file",
"backupId", backup.ID,
"fileName", metadataFileName,
"error", err,
backupLog.Error(
fmt.Sprintf("failed to delete stale basebackup metadata file: %s", metadataFileName),
"error",
err,
)
}
}
@@ -147,77 +152,67 @@ func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
backup.FailMessage = &failMsg
if err := c.backupRepository.Save(backup); err != nil {
c.logger.Error(
"Failed to mark stale uploaded basebackup as failed",
"backupId", backup.ID,
"error", err,
)
backupLog.Error("failed to mark stale uploaded basebackup as failed", "error", err)
continue
}
c.logger.Info(
"Marked stale uploaded basebackup as failed and cleaned storage",
"backupId", backup.ID,
"databaseId", backup.DatabaseID,
)
backupLog.Info("marked stale uploaded basebackup as failed and cleaned storage")
}
return nil
}
func (c *BackupCleaner) cleanByRetentionPolicy() error {
func (c *BackupCleaner) cleanByRetentionPolicy(logger *slog.Logger) error {
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
dbLog := logger.With("database_id", backupConfig.DatabaseID, "policy", backupConfig.RetentionPolicyType)
var cleanErr error
switch backupConfig.RetentionPolicyType {
case backups_config.RetentionPolicyTypeCount:
cleanErr = c.cleanByCount(backupConfig)
cleanErr = c.cleanByCount(dbLog, backupConfig)
case backups_config.RetentionPolicyTypeGFS:
cleanErr = c.cleanByGFS(backupConfig)
cleanErr = c.cleanByGFS(dbLog, backupConfig)
default:
cleanErr = c.cleanByTimePeriod(backupConfig)
cleanErr = c.cleanByTimePeriod(dbLog, backupConfig)
}
if cleanErr != nil {
c.logger.Error(
"Failed to clean backups by retention policy",
"databaseId", backupConfig.DatabaseID,
"policy", backupConfig.RetentionPolicyType,
"error", cleanErr,
)
dbLog.Error("failed to clean backups by retention policy", "error", cleanErr)
}
}
return nil
}
func (c *BackupCleaner) cleanExceededBackups() error {
func (c *BackupCleaner) cleanExceededStorageBackups(logger *slog.Logger) error {
if !config.GetEnv().IsCloud {
return nil
}
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
if backupConfig.MaxBackupsTotalSizeMB <= 0 {
dbLog := logger.With("database_id", backupConfig.DatabaseID)
subscription, subErr := c.billingService.GetSubscription(dbLog, backupConfig.DatabaseID)
if subErr != nil {
dbLog.Error("failed to get subscription for exceeded backups check", "error", subErr)
continue
}
if err := c.cleanExceededBackupsForDatabase(
backupConfig.DatabaseID,
backupConfig.MaxBackupsTotalSizeMB,
); err != nil {
c.logger.Error(
"Failed to clean exceeded backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
storageLimitMB := int64(subscription.GetBackupsStorageGB()) * 1024
if err := c.cleanExceededBackupsForDatabase(dbLog, backupConfig.DatabaseID, storageLimitMB); err != nil {
dbLog.Error("failed to clean exceeded backups for database", "error", err)
continue
}
}
@@ -225,7 +220,7 @@ func (c *BackupCleaner) cleanExceededBackups() error {
return nil
}
func (c *BackupCleaner) cleanByTimePeriod(backupConfig *backups_config.BackupConfig) error {
func (c *BackupCleaner) cleanByTimePeriod(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
if backupConfig.RetentionTimePeriod == "" {
return nil
}
@@ -255,21 +250,17 @@ func (c *BackupCleaner) cleanByTimePeriod(backupConfig *backups_config.BackupCon
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
logger.Error("failed to delete old backup", "backup_id", backup.ID, "error", err)
continue
}
c.logger.Info(
"Deleted old backup",
"backupId", backup.ID,
"databaseId", backupConfig.DatabaseID,
)
logger.Info("deleted old backup", "backup_id", backup.ID)
}
return nil
}
func (c *BackupCleaner) cleanByCount(backupConfig *backups_config.BackupConfig) error {
func (c *BackupCleaner) cleanByCount(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
if backupConfig.RetentionCount <= 0 {
return nil
}
@@ -298,28 +289,20 @@ func (c *BackupCleaner) cleanByCount(backupConfig *backups_config.BackupConfig)
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete backup by count policy",
"backupId",
backup.ID,
"error",
err,
)
logger.Error("failed to delete backup by count policy", "backup_id", backup.ID, "error", err)
continue
}
c.logger.Info(
"Deleted backup by count policy",
"backupId", backup.ID,
"databaseId", backupConfig.DatabaseID,
"retentionCount", backupConfig.RetentionCount,
logger.Info(
fmt.Sprintf("deleted backup by count policy: retention count is %d", backupConfig.RetentionCount),
"backup_id", backup.ID,
)
}
return nil
}
func (c *BackupCleaner) cleanByGFS(backupConfig *backups_config.BackupConfig) error {
func (c *BackupCleaner) cleanByGFS(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
if backupConfig.RetentionGfsHours <= 0 && backupConfig.RetentionGfsDays <= 0 &&
backupConfig.RetentionGfsWeeks <= 0 && backupConfig.RetentionGfsMonths <= 0 &&
backupConfig.RetentionGfsYears <= 0 {
@@ -357,29 +340,20 @@ func (c *BackupCleaner) cleanByGFS(backupConfig *backups_config.BackupConfig) er
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete backup by GFS policy",
"backupId",
backup.ID,
"error",
err,
)
logger.Error("failed to delete backup by GFS policy", "backup_id", backup.ID, "error", err)
continue
}
c.logger.Info(
"Deleted backup by GFS policy",
"backupId", backup.ID,
"databaseId", backupConfig.DatabaseID,
)
logger.Info("deleted backup by GFS policy", "backup_id", backup.ID)
}
return nil
}
func (c *BackupCleaner) cleanExceededBackupsForDatabase(
logger *slog.Logger,
databaseID uuid.UUID,
limitperDbMB int64,
limitPerDbMB int64,
) error {
for {
backupsTotalSizeMB, err := c.backupRepository.GetTotalSizeByDatabase(databaseID)
@@ -387,7 +361,7 @@ func (c *BackupCleaner) cleanExceededBackupsForDatabase(
return err
}
if backupsTotalSizeMB <= float64(limitperDbMB) {
if backupsTotalSizeMB <= float64(limitPerDbMB) {
break
}
@@ -400,59 +374,27 @@ func (c *BackupCleaner) cleanExceededBackupsForDatabase(
}
if len(oldestBackups) == 0 {
c.logger.Warn(
"No backups to delete but still over limit",
"databaseId",
databaseID,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
)
logger.Warn(fmt.Sprintf(
"no backups to delete but still over limit: total size is %.1f MB, limit is %d MB",
backupsTotalSizeMB, limitPerDbMB,
))
break
}
backup := oldestBackups[0]
if isRecentBackup(backup) {
c.logger.Warn(
"Oldest backup is too recent to delete, stopping size cleanup",
"databaseId",
databaseID,
"backupId",
backup.ID,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
)
break
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete exceeded backup",
"backupId",
backup.ID,
"databaseId",
databaseID,
"error",
err,
)
logger.Error("failed to delete exceeded backup", "backup_id", backup.ID, "error", err)
return err
}
c.logger.Info(
"Deleted exceeded backup",
"backupId",
backup.ID,
"databaseId",
databaseID,
"backupSizeMB",
backup.BackupSizeMb,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
logger.Info(
fmt.Sprintf("deleted exceeded backup: backup size is %.1f MB, total size is %.1f MB, limit is %d MB",
backup.BackupSizeMb, backupsTotalSizeMB, limitPerDbMB),
"backup_id", backup.ID,
)
}

View File

@@ -425,7 +425,7 @@ func Test_CleanByGFS_KeepsCorrectBackupsPerSlot(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -502,7 +502,7 @@ func Test_CleanByGFS_WithWeeklyAndMonthlySlots_KeepsWiderSpread(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -576,7 +576,7 @@ func Test_CleanByGFS_WithHourlySlots_KeepsCorrectBackups(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -677,7 +677,7 @@ func Test_CleanByGFS_SkipsRecentBackup_WhenNotInKeepSet(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -759,7 +759,7 @@ func Test_CleanByGFS_With20DailyBackups_KeepsOnlyExpectedCount(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -844,7 +844,7 @@ func Test_CleanByGFS_WithMultipleBackupsPerDay_KeepsOnlyOnePerDailySlot(t *testi
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -929,7 +929,7 @@ func Test_CleanByGFS_With24HourlySlotsAnd23DailyBackups_DeletesExcessBackups(t *
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -999,7 +999,7 @@ func Test_CleanByGFS_WithDisabledHourlySlotsAnd23DailyBackups_DeletesExcessBacku
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -1069,7 +1069,7 @@ func Test_CleanByGFS_WithDailySlotsAndWeeklyBackups_DeletesExcessBackups(t *test
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -1152,7 +1152,7 @@ func Test_CleanByGFS_WithWeeklySlotsAndMonthlyBackups_DeletesExcessBackups(t *te
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)

View File

@@ -1,14 +1,17 @@
package backuping
import (
"log/slog"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
billing_models "databasus-backend/internal/features/billing/models"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
@@ -17,6 +20,7 @@ import (
users_testing "databasus-backend/internal/features/users/testing"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/storage"
"databasus-backend/internal/util/logger"
"databasus-backend/internal/util/period"
)
@@ -51,6 +55,7 @@ func Test_CleanOldBackups_DeletesBackupsOlderThanRetentionTimePeriod(t *testing.
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -89,7 +94,7 @@ func Test_CleanOldBackups_DeletesBackupsOlderThanRetentionTimePeriod(t *testing.
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -129,6 +134,7 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -145,7 +151,7 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -154,7 +160,8 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
assert.Equal(t, oldBackup.ID, remainingBackups[0].ID)
}
func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
func Test_CleanExceededBackups_WhenUnderStorageLimit_NoBackupsDeleted(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -178,14 +185,14 @@ func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 100,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -196,15 +203,18 @@ func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 16.67,
BackupSizeMb: 100,
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
}
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -212,7 +222,8 @@ func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
assert.Equal(t, 3, len(remainingBackups))
}
func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T) {
func Test_CleanExceededBackups_WhenOverStorageLimit_DeletesOldestBackups(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -236,18 +247,20 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 30,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// 5 backups at 300 MB each = 1500 MB total, limit = 1 GB (1024 MB)
// Expect 2 oldest deleted, 3 remain (900 MB < 1024 MB)
now := time.Now().UTC()
var backupIDs []uuid.UUID
for i := 0; i < 5; i++ {
@@ -256,7 +269,7 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
BackupSizeMb: 300,
CreatedAt: now.Add(-time.Duration(4-i) * time.Hour),
}
err = backupRepository.Save(backup)
@@ -264,8 +277,11 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
backupIDs = append(backupIDs, backup.ID)
}
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -284,6 +300,7 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
}
func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -307,20 +324,21 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 50,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
now := time.Now().UTC()
// 3 completed at 500 MB each = 1500 MB, limit = 1 GB (1024 MB)
completedBackups := make([]*backups_core.Backup, 3)
for i := 0; i < 3; i++ {
backup := &backups_core.Backup{
@@ -328,7 +346,7 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 30,
BackupSizeMb: 500,
CreatedAt: now.Add(-time.Duration(3-i) * time.Hour),
}
err = backupRepository.Save(backup)
@@ -347,8 +365,11 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
err = backupRepository.Save(inProgressBackup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -365,7 +386,8 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
assert.True(t, inProgressFound, "In-progress backup should not be deleted")
}
func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
func Test_CleanExceededBackups_WithZeroStorageLimit_RemovesAllBackups(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -389,14 +411,14 @@ func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 0,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -408,19 +430,23 @@ func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 100,
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
CreatedAt: time.Now().UTC().Add(-time.Duration(i+2) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
}
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
// StorageGB=0 means no storage allowed — all backups should be removed
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 0, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 10, len(remainingBackups))
assert.Equal(t, 0, len(remainingBackups))
}
func Test_GetTotalSizeByDatabase_CalculatesCorrectly(t *testing.T) {
@@ -522,6 +548,7 @@ func Test_CleanByCount_KeepsNewestNBackups_DeletesOlder(t *testing.T) {
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -545,7 +572,7 @@ func Test_CleanByCount_KeepsNewestNBackups_DeletesOlder(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -594,6 +621,7 @@ func Test_CleanByCount_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -612,7 +640,7 @@ func Test_CleanByCount_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -651,6 +679,7 @@ func Test_CleanByCount_DoesNotDeleteInProgressBackups(t *testing.T) {
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -682,7 +711,7 @@ func Test_CleanByCount_DoesNotDeleteInProgressBackups(t *testing.T) {
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -776,6 +805,7 @@ func Test_CleanByTimePeriod_SkipsRecentBackup_EvenIfOlderThanRetention(t *testin
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -805,7 +835,7 @@ func Test_CleanByTimePeriod_SkipsRecentBackup_EvenIfOlderThanRetention(t *testin
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -847,6 +877,7 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -893,7 +924,7 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -914,7 +945,8 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
assert.True(t, remainingIDs[newestBackup.ID], "Newest backup should be preserved")
}
func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testing.T) {
func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverStorageLimit(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -937,18 +969,18 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
interval := createTestInterval()
// Total size limit is 10 MB. We have two backups of 8 MB each (16 MB total).
// Total size limit = 1 GB (1024 MB). Two backups of 600 MB each (1200 MB total).
// The oldest backup was created 30 minutes ago — within the grace period.
// The cleaner must stop and leave both backups intact.
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 10,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -960,7 +992,7 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 8,
BackupSizeMb: 600,
CreatedAt: now.Add(-30 * time.Minute),
}
newerRecentBackup := &backups_core.Backup{
@@ -968,7 +1000,7 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 8,
BackupSizeMb: 600,
CreatedAt: now.Add(-10 * time.Minute),
}
@@ -977,8 +1009,11 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
err = backupRepository.Save(newerRecentBackup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -991,6 +1026,82 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
)
}
// Verifies that the storage-limit cleaner is a no-op in non-cloud
// (self-hosted) mode: even though total backup size (2500 MB) would exceed
// the mock subscription's 1 GB limit, no backups may be deleted because
// IsCloud is false.
func Test_CleanExceededStorageBackups_WhenNonCloud_SkipsCleanup(t *testing.T) {
	router := CreateTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	defer func() {
		// Delete backups first so dependent rows do not block entity removal.
		backups, _ := backupRepository.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepository.DeleteByID(backup.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		notifiers.RemoveTestNotifier(notifier)
		storages.RemoveTestStorage(storage.ID)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	interval := createTestInterval()
	backupConfig := &backups_config.BackupConfig{
		DatabaseID:          database.ID,
		IsBackupsEnabled:    true,
		RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
		RetentionTimePeriod: period.PeriodForever,
		StorageID:           &storage.ID,
		BackupIntervalID:    interval.ID,
		BackupInterval:      interval,
		Encryption:          backups_config.BackupEncryptionEncrypted,
	}
	_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
	assert.NoError(t, err)
	// 5 backups at 500 MB each = 2500 MB, would exceed 1 GB limit in cloud mode
	now := time.Now().UTC()
	for i := 0; i < 5; i++ {
		backup := &backups_core.Backup{
			ID:           uuid.New(),
			DatabaseID:   database.ID,
			StorageID:    storage.ID,
			Status:       backups_core.BackupStatusCompleted,
			BackupSizeMb: 500,
			// i+2 hours old, so every backup is outside any recency grace period.
			CreatedAt: now.Add(-time.Duration(i+2) * time.Hour),
		}
		err = backupRepository.Save(backup)
		assert.NoError(t, err)
	}
	// IsCloud is false by default — cleaner should skip entirely
	mockBilling := &mockBillingService{
		subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
	}
	cleaner := CreateTestBackupCleaner(mockBilling)
	err = cleaner.cleanExceededStorageBackups(testLogger())
	assert.NoError(t, err)
	remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
	assert.NoError(t, err)
	assert.Equal(t, 5, len(remainingBackups), "All backups must remain in non-cloud mode")
}
// mockBillingService is a configurable BillingService test double. It always
// returns the pre-set subscription and error pair, regardless of which
// database is queried.
type mockBillingService struct {
	subscription *billing_models.Subscription
	err          error
}

// GetSubscription ignores both arguments and returns the stubbed values.
func (m *mockBillingService) GetSubscription(_ *slog.Logger, _ uuid.UUID) (*billing_models.Subscription, error) {
	return m.subscription, m.err
}
// mockBackupRemoveListener is a test double for backups_core.BackupRemoveListener.
type mockBackupRemoveListener struct {
onBeforeBackupRemove func(*backups_core.Backup) error
@@ -1041,7 +1152,7 @@ func Test_CleanStaleUploadedBasebackups_MarksAsFailed(t *testing.T) {
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups()
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
assert.NoError(t, err)
updated, err := backupRepository.FindByID(staleBackup.ID)
@@ -1088,7 +1199,7 @@ func Test_CleanStaleUploadedBasebackups_SkipsRecentUploads(t *testing.T) {
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups()
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
assert.NoError(t, err)
updated, err := backupRepository.FindByID(recentBackup.ID)
@@ -1131,7 +1242,7 @@ func Test_CleanStaleUploadedBasebackups_SkipsActiveStreaming(t *testing.T) {
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups()
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
assert.NoError(t, err)
updated, err := backupRepository.FindByID(activeBackup.ID)
@@ -1179,7 +1290,7 @@ func Test_CleanStaleUploadedBasebackups_CleansStorageFiles(t *testing.T) {
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups()
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
assert.NoError(t, err)
updated, err := backupRepository.FindByID(staleBackup.ID)
@@ -1189,6 +1300,18 @@ func Test_CleanStaleUploadedBasebackups_CleansStorageFiles(t *testing.T) {
assert.Contains(t, *updated.FailMessage, "finalization timed out")
}
// enableCloud switches the runtime environment into cloud mode for the
// duration of the test and restores self-hosted mode via t.Cleanup.
func enableCloud(t *testing.T) {
	t.Helper()
	env := config.GetEnv()
	env.IsCloud = true
	t.Cleanup(func() {
		env.IsCloud = false
	})
}
// testLogger returns the application logger tagged with a test task name so
// cleaner/scheduler log lines produced by tests are identifiable.
func testLogger() *slog.Logger {
	base := logger.GetLogger()
	return base.With("task_name", "test")
}
func createTestInterval() *intervals.Interval {
timeOfDay := "04:00"
interval := &intervals.Interval{

View File

@@ -10,6 +10,7 @@ import (
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/billing"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
@@ -28,6 +29,7 @@ var backupCleaner = &BackupCleaner{
backupRepository,
storages.GetStorageService(),
backups_config.GetBackupConfigService(),
billing.GetBillingService(),
encryption.GetFieldEncryptor(),
logger.GetLogger(),
[]backups_core.BackupRemoveListener{},
@@ -73,6 +75,7 @@ var backupsScheduler = &BackupsScheduler{
taskCancelManager,
backupNodesRegistry,
databases.GetDatabaseService(),
billing.GetBillingService(),
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]BackupToNodeRelation),

View File

@@ -0,0 +1,13 @@
package backuping

import (
	"log/slog"

	"github.com/google/uuid"

	billing_models "databasus-backend/internal/features/billing/models"
)

// BillingService abstracts subscription lookups for the backuping feature so
// the scheduler and cleaner can be tested with mock billing implementations.
type BillingService interface {
	// GetSubscription returns the subscription governing the given database.
	GetSubscription(logger *slog.Logger, databaseID uuid.UUID) (*billing_models.Subscription, error)
}

View File

@@ -29,6 +29,7 @@ type BackupsScheduler struct {
taskCancelManager *task_cancellation.TaskCancelManager
backupNodesRegistry *BackupNodesRegistry
databaseService *databases.DatabaseService
billingService BillingService
lastBackupTime time.Time
logger *slog.Logger
@@ -127,6 +128,34 @@ func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotif
return
}
if config.GetEnv().IsCloud {
subscription, subErr := s.billingService.GetSubscription(s.logger, database.ID)
if subErr != nil || !subscription.CanCreateNewBackups() {
failMessage := "subscription has expired, please renew"
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: *backupConfig.StorageID,
Status: backups_core.BackupStatusFailed,
FailMessage: &failMessage,
IsSkipRetry: true,
CreatedAt: time.Now().UTC(),
}
backup.GenerateFilename(database.Name)
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error(
"failed to save failed backup for expired subscription",
"database_id", database.ID,
"error", err,
)
}
return
}
}
// Check for existing in-progress backups
inProgressBackups, err := s.backupRepository.FindByDatabaseIdAndStatus(
database.ID,
@@ -346,6 +375,27 @@ func (s *BackupsScheduler) runPendingBackups() error {
continue
}
if config.GetEnv().IsCloud {
subscription, subErr := s.billingService.GetSubscription(s.logger, backupConfig.DatabaseID)
if subErr != nil {
s.logger.Warn(
"failed to get subscription, skipping backup",
"database_id", backupConfig.DatabaseID,
"error", subErr,
)
continue
}
if !subscription.CanCreateNewBackups() {
s.logger.Debug(
"subscription is not active, skipping scheduled backup",
"database_id", backupConfig.DatabaseID,
"subscription_status", subscription.Status,
)
continue
}
}
s.StartBackup(database, remainedBackupTryCount == 1)
continue
}

View File

@@ -10,6 +10,7 @@ import (
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
billing_models "databasus-backend/internal/features/billing/models"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
@@ -968,7 +969,7 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
scheduler := CreateTestScheduler()
scheduler := CreateTestScheduler(nil)
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
@@ -1065,7 +1066,7 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
scheduler := CreateTestScheduler()
scheduler := CreateTestScheduler(nil)
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
@@ -1332,7 +1333,7 @@ func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCa
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// Create scheduler
scheduler := CreateTestScheduler()
scheduler := CreateTestScheduler(nil)
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
@@ -1458,3 +1459,313 @@ func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCa
time.Sleep(200 * time.Millisecond)
}
// Verifies that in cloud mode an expired subscription blocks a manual backup:
// StartBackup must not run the backup, but instead persist a failed,
// non-retryable backup whose message tells the user to renew.
func Test_StartBackup_WhenCloudAndSubscriptionExpired_CreatesFailedBackup(t *testing.T) {
	cache_utils.ClearAllCache()
	mockBilling := &mockBillingService{
		subscription: &billing_models.Subscription{
			Status: billing_models.StatusExpired,
		},
	}
	scheduler := CreateTestScheduler(mockBilling)
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	defer func() {
		// Delete backups first so dependent rows do not block entity removal.
		backups, _ := backupRepository.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepository.DeleteByID(backup.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
	assert.NoError(t, err)
	backupConfig.Storage = storage
	backupConfig.StorageID = &storage.ID
	_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
	assert.NoError(t, err)
	enableCloud(t)
	scheduler.StartBackup(database, false)
	// Give the scheduler a moment to persist the failed backup record.
	time.Sleep(100 * time.Millisecond)
	backups, err := backupRepository.FindByDatabaseID(database.ID)
	assert.NoError(t, err)
	assert.Len(t, backups, 1)
	newestBackup := backups[0]
	assert.Equal(t, backups_core.BackupStatusFailed, newestBackup.Status)
	assert.NotNil(t, newestBackup.FailMessage)
	assert.Equal(t, "subscription has expired, please renew", *newestBackup.FailMessage)
	assert.True(t, newestBackup.IsSkipRetry)
	time.Sleep(200 * time.Millisecond)
}
// Verifies that in cloud mode an active subscription lets a manual backup run
// end-to-end: StartBackup proceeds and the backup completes successfully.
func Test_StartBackup_WhenCloudAndSubscriptionActive_ProceedsNormally(t *testing.T) {
	cache_utils.ClearAllCache()
	// A live backuper node is required for the backup to actually execute.
	backuperNode := CreateTestBackuperNode()
	cancel := StartBackuperNodeForTest(t, backuperNode)
	defer StopBackuperNodeForTest(t, cancel, backuperNode)
	mockBilling := &mockBillingService{
		subscription: &billing_models.Subscription{
			Status:    billing_models.StatusActive,
			StorageGB: 10,
		},
	}
	scheduler := CreateTestScheduler(mockBilling)
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	defer func() {
		// Delete backups first so dependent rows do not block entity removal.
		backups, _ := backupRepository.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepository.DeleteByID(backup.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
	assert.NoError(t, err)
	backupConfig.Storage = storage
	backupConfig.StorageID = &storage.ID
	_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
	assert.NoError(t, err)
	enableCloud(t)
	scheduler.StartBackup(database, false)
	WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
	backups, err := backupRepository.FindByDatabaseID(database.ID)
	assert.NoError(t, err)
	assert.Len(t, backups, 1)
	newestBackup := backups[0]
	assert.Equal(t, backups_core.BackupStatusCompleted, newestBackup.Status)
	time.Sleep(200 * time.Millisecond)
}
// Verifies that in cloud mode the scheduler silently skips databases whose
// subscription is expired: runPendingBackups creates no new backup, failed or
// otherwise (unlike StartBackup, which records a failed backup).
//
// Fix over the original: the errors returned by backupRepository.Save and
// scheduler.runPendingBackups were silently dropped; both are now asserted so
// a broken fixture cannot masquerade as a passing skip.
func Test_RunPendingBackups_WhenCloudAndSubscriptionExpired_SilentlySkips(t *testing.T) {
	cache_utils.ClearAllCache()
	mockBilling := &mockBillingService{
		subscription: &billing_models.Subscription{
			Status: billing_models.StatusExpired,
		},
	}
	scheduler := CreateTestScheduler(mockBilling)
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	defer func() {
		// Delete backups first so dependent rows do not block entity removal.
		backups, _ := backupRepository.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepository.DeleteByID(backup.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
	assert.NoError(t, err)
	timeOfDay := "04:00"
	backupConfig.BackupInterval = &intervals.Interval{
		Interval:  intervals.IntervalDaily,
		TimeOfDay: &timeOfDay,
	}
	backupConfig.IsBackupsEnabled = true
	backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
	backupConfig.RetentionTimePeriod = period.PeriodWeek
	backupConfig.Storage = storage
	backupConfig.StorageID = &storage.ID
	_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
	assert.NoError(t, err)
	// A day-old completed backup makes the daily schedule due again.
	err = backupRepository.Save(&backups_core.Backup{
		DatabaseID: database.ID,
		StorageID:  storage.ID,
		Status:     backups_core.BackupStatusCompleted,
		CreatedAt:  time.Now().UTC().Add(-24 * time.Hour),
	})
	assert.NoError(t, err)
	enableCloud(t)
	err = scheduler.runPendingBackups()
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)
	backups, err := backupRepository.FindByDatabaseID(database.ID)
	assert.NoError(t, err)
	assert.Len(t, backups, 1, "No new backup should be created, scheduler silently skips expired subscriptions")
	time.Sleep(200 * time.Millisecond)
}
// Verifies that subscription status is irrelevant in self-hosted mode: with
// IsCloud left at its default (false), StartBackup runs to completion even
// though the mock subscription is expired.
func Test_StartBackup_WhenNotCloudAndSubscriptionExpired_ProceedsNormally(t *testing.T) {
	cache_utils.ClearAllCache()
	// A live backuper node is required for the backup to actually execute.
	backuperNode := CreateTestBackuperNode()
	cancel := StartBackuperNodeForTest(t, backuperNode)
	defer StopBackuperNodeForTest(t, cancel, backuperNode)
	mockBilling := &mockBillingService{
		subscription: &billing_models.Subscription{
			Status: billing_models.StatusExpired,
		},
	}
	scheduler := CreateTestScheduler(mockBilling)
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	defer func() {
		// Delete backups first so dependent rows do not block entity removal.
		backups, _ := backupRepository.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepository.DeleteByID(backup.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
	assert.NoError(t, err)
	backupConfig.Storage = storage
	backupConfig.StorageID = &storage.ID
	_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
	assert.NoError(t, err)
	// enableCloud is intentionally NOT called here: this is the non-cloud path.
	scheduler.StartBackup(database, false)
	WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
	backups, err := backupRepository.FindByDatabaseID(database.ID)
	assert.NoError(t, err)
	assert.Len(t, backups, 1)
	assert.Equal(t, backups_core.BackupStatusCompleted, backups[0].Status,
		"Billing check should not apply in non-cloud mode")
	time.Sleep(200 * time.Millisecond)
}
// Verifies that subscription status is irrelevant in self-hosted mode for the
// scheduler path: with IsCloud false, runPendingBackups schedules and
// completes a new backup even though the mock subscription is expired.
//
// Fix over the original: the errors returned by backupRepository.Save and
// scheduler.runPendingBackups were silently dropped; both are now asserted so
// a broken fixture cannot masquerade as a passing run.
func Test_RunPendingBackups_WhenNotCloudAndSubscriptionExpired_ProceedsNormally(t *testing.T) {
	cache_utils.ClearAllCache()
	// A live backuper node is required for the backup to actually execute.
	backuperNode := CreateTestBackuperNode()
	cancel := StartBackuperNodeForTest(t, backuperNode)
	defer StopBackuperNodeForTest(t, cancel, backuperNode)
	mockBilling := &mockBillingService{
		subscription: &billing_models.Subscription{
			Status: billing_models.StatusExpired,
		},
	}
	scheduler := CreateTestScheduler(mockBilling)
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	defer func() {
		// Delete backups first so dependent rows do not block entity removal.
		backups, _ := backupRepository.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepository.DeleteByID(backup.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
	assert.NoError(t, err)
	timeOfDay := "04:00"
	backupConfig.BackupInterval = &intervals.Interval{
		Interval:  intervals.IntervalDaily,
		TimeOfDay: &timeOfDay,
	}
	backupConfig.IsBackupsEnabled = true
	backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
	backupConfig.RetentionTimePeriod = period.PeriodWeek
	backupConfig.Storage = storage
	backupConfig.StorageID = &storage.ID
	_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
	assert.NoError(t, err)
	// A day-old completed backup makes the daily schedule due again.
	err = backupRepository.Save(&backups_core.Backup{
		DatabaseID: database.ID,
		StorageID:  storage.ID,
		Status:     backups_core.BackupStatusCompleted,
		CreatedAt:  time.Now().UTC().Add(-24 * time.Hour),
	})
	assert.NoError(t, err)
	// IsCloud is false by default — the billing check must not fire.
	err = scheduler.runPendingBackups()
	assert.NoError(t, err)
	WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
	backups, err := backupRepository.FindByDatabaseID(database.ID)
	assert.NoError(t, err)
	assert.Len(t, backups, 2, "Billing check should not apply in non-cloud mode, new backup should be created")
	time.Sleep(200 * time.Millisecond)
}

View File

@@ -35,58 +35,74 @@ func CreateTestRouter() *gin.Engine {
return router
}
// CreateTestBackupCleaner builds a BackupCleaner for tests, wired with the
// package-level test backup repository and real singleton services, but with
// the caller-supplied billing service so tests can stub billing behavior
// (pass nil where billing is irrelevant).
//
// NOTE(review): the struct literal is positional — its order must match the
// BackupCleaner field declaration exactly; confirm when fields change.
func CreateTestBackupCleaner(billingService BillingService) *BackupCleaner {
	return &BackupCleaner{
		backupRepository,
		storages.GetStorageService(),
		backups_config.GetBackupConfigService(),
		billingService,
		encryption.GetFieldEncryptor(),
		logger.GetLogger(),
		// Start with no remove listeners; tests register their own if needed.
		[]backups_core.BackupRemoveListener{},
		sync.Once{},
		atomic.Bool{},
	}
}
// CreateTestBackuperNode builds a BackuperNode for tests with a fresh node ID,
// zero-value heartbeat, and the package-level test repository and registry.
//
// NOTE(review): this span is rendered diff residue — it interleaves the
// removed field-named initializer with the added positional one; only the
// positional set exists in the file after this commit. The positional order
// must match the BackuperNode field declaration — confirm against the struct.
func CreateTestBackuperNode() *BackuperNode {
	return &BackuperNode{
		// removed (pre-change) field-named initializer:
		databaseService: databases.GetDatabaseService(),
		fieldEncryptor: encryption.GetFieldEncryptor(),
		workspaceService: workspaces_services.GetWorkspaceService(),
		backupRepository: backupRepository,
		backupConfigService: backups_config.GetBackupConfigService(),
		storageService: storages.GetStorageService(),
		notificationSender: notifiers.GetNotifierService(),
		backupCancelManager: taskCancelManager,
		backupNodesRegistry: backupNodesRegistry,
		logger: logger.GetLogger(),
		createBackupUseCase: usecases.GetCreateBackupUsecase(),
		nodeID: uuid.New(),
		lastHeartbeat: time.Time{},
		runOnce: sync.Once{},
		hasRun: atomic.Bool{},
		// added (post-change) positional initializer:
		databases.GetDatabaseService(),
		encryption.GetFieldEncryptor(),
		workspaces_services.GetWorkspaceService(),
		backupRepository,
		backups_config.GetBackupConfigService(),
		storages.GetStorageService(),
		notifiers.GetNotifierService(),
		taskCancelManager,
		backupNodesRegistry,
		logger.GetLogger(),
		usecases.GetCreateBackupUsecase(),
		uuid.New(),
		time.Time{},
		sync.Once{},
		atomic.Bool{},
	}
}
// CreateTestBackuperNodeWithUseCase is CreateTestBackuperNode with the
// backup-creation use case injected, so tests can substitute a mock or
// failing implementation.
//
// NOTE(review): rendered diff residue — the removed field-named initializer
// and the added positional one are interleaved below; only the positional
// set exists after this commit. Keep positional order in sync with the
// BackuperNode declaration.
func CreateTestBackuperNodeWithUseCase(useCase backups_core.CreateBackupUsecase) *BackuperNode {
	return &BackuperNode{
		// removed (pre-change) field-named initializer:
		databaseService: databases.GetDatabaseService(),
		fieldEncryptor: encryption.GetFieldEncryptor(),
		workspaceService: workspaces_services.GetWorkspaceService(),
		backupRepository: backupRepository,
		backupConfigService: backups_config.GetBackupConfigService(),
		storageService: storages.GetStorageService(),
		notificationSender: notifiers.GetNotifierService(),
		backupCancelManager: taskCancelManager,
		backupNodesRegistry: backupNodesRegistry,
		logger: logger.GetLogger(),
		createBackupUseCase: useCase,
		nodeID: uuid.New(),
		lastHeartbeat: time.Time{},
		runOnce: sync.Once{},
		hasRun: atomic.Bool{},
		// added (post-change) positional initializer:
		databases.GetDatabaseService(),
		encryption.GetFieldEncryptor(),
		workspaces_services.GetWorkspaceService(),
		backupRepository,
		backups_config.GetBackupConfigService(),
		storages.GetStorageService(),
		notifiers.GetNotifierService(),
		taskCancelManager,
		backupNodesRegistry,
		logger.GetLogger(),
		useCase,
		uuid.New(),
		time.Time{},
		sync.Once{},
		atomic.Bool{},
	}
}
// CreateTestScheduler builds a BackupsScheduler for tests; after this commit
// it takes the billing service explicitly (nil disables billing stubbing) and
// additionally wires the database service for cloud billing checks.
//
// NOTE(review): rendered diff residue — both the removed signature/initializer
// and the added ones are interleaved below; only the second signature and the
// positional initializer exist after this commit.
// removed (pre-change) signature:
func CreateTestScheduler() *BackupsScheduler {
// added (post-change) signature:
func CreateTestScheduler(billingService BillingService) *BackupsScheduler {
	return &BackupsScheduler{
		// removed (pre-change) field-named initializer:
		backupRepository: backupRepository,
		backupConfigService: backups_config.GetBackupConfigService(),
		taskCancelManager: taskCancelManager,
		backupNodesRegistry: backupNodesRegistry,
		lastBackupTime: time.Now().UTC(),
		logger: logger.GetLogger(),
		backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
		backuperNode: CreateTestBackuperNode(),
		runOnce: sync.Once{},
		hasRun: atomic.Bool{},
		// added (post-change) positional initializer:
		backupRepository,
		backups_config.GetBackupConfigService(),
		taskCancelManager,
		backupNodesRegistry,
		databases.GetDatabaseService(),
		billingService,
		time.Now().UTC(),
		logger.GetLogger(),
		make(map[uuid.UUID]BackupToNodeRelation),
		CreateTestBackuperNode(),
		sync.Once{},
		atomic.Bool{},
	}
}

View File

@@ -1263,7 +1263,7 @@ func Test_MakeBackup_VerifyBackupAndMetadataFilesExistInStorage(t *testing.T) {
backuperCancel := backuping.StartBackuperNodeForTest(t, backuperNode)
defer backuping.StopBackuperNodeForTest(t, backuperCancel, backuperNode)
scheduler := backuping.CreateTestScheduler()
scheduler := backuping.CreateTestScheduler(nil)
schedulerCancel := backuping.StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
@@ -1838,7 +1838,7 @@ func Test_DeleteBackup_RemovesBackupAndMetadataFilesFromDisk(t *testing.T) {
backuperCancel := backuping.StartBackuperNodeForTest(t, backuperNode)
defer backuping.StopBackuperNodeForTest(t, backuperCancel, backuperNode)
scheduler := backuping.CreateTestScheduler()
scheduler := backuping.CreateTestScheduler(nil)
schedulerCancel := backuping.StartSchedulerForTest(t, scheduler)
defer schedulerCancel()

View File

@@ -16,7 +16,6 @@ type BackupConfigController struct {
func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
router.POST("/backup-configs/save", c.SaveBackupConfig)
router.GET("/backup-configs/database/:id/plan", c.GetDatabasePlan)
router.GET("/backup-configs/database/:id", c.GetBackupConfigByDbID)
router.GET("/backup-configs/storage/:id/is-using", c.IsStorageUsing)
router.GET("/backup-configs/storage/:id/databases-count", c.CountDatabasesForStorage)
@@ -93,39 +92,6 @@ func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
ctx.JSON(http.StatusOK, backupConfig)
}
// GetDatabasePlan
// @Summary Get database plan by database ID
// @Description Get the plan limits for a specific database (max backup size, max total size, max storage period)
// @Tags backup-configs
// @Produce json
// @Param id path string true "Database ID"
// @Success 200 {object} plans.DatabasePlan
// @Failure 400 {object} map[string]string "Invalid database ID"
// @Failure 401 {object} map[string]string "User not authenticated"
// @Failure 404 {object} map[string]string "Database not found or access denied"
// @Router /backup-configs/database/{id}/plan [get]
func (c *BackupConfigController) GetDatabasePlan(ctx *gin.Context) {
	// The auth middleware is expected to have stored the user on the context.
	currentUser, authenticated := users_middleware.GetUserFromContext(ctx)
	if !authenticated {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}

	// The :id path segment must parse as a UUID.
	databaseID, parseErr := uuid.Parse(ctx.Param("id"))
	if parseErr != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
		return
	}

	// The service also enforces that the user may access this database;
	// any failure is reported uniformly as "not found".
	plan, lookupErr := c.backupConfigService.GetDatabasePlan(currentUser, databaseID)
	if lookupErr != nil {
		ctx.JSON(http.StatusNotFound, gin.H{"error": "database plan not found"})
		return
	}

	ctx.JSON(http.StatusOK, plan)
}
// IsStorageUsing
// @Summary Check if storage is being used
// @Description Check if a storage is currently being used by any backup configuration

View File

@@ -17,14 +17,12 @@ import (
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
local_storage "databasus-backend/internal/features/storages/models/local"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/storage"
"databasus-backend/internal/util/period"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
@@ -326,218 +324,13 @@ func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T)
&response,
)
var plan plans.DatabasePlan
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
"Bearer "+owner.Token,
http.StatusOK,
&plan,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.False(t, response.IsBackupsEnabled)
assert.Equal(t, plan.MaxStoragePeriod, response.RetentionTimePeriod)
assert.Equal(t, plan.MaxBackupSizeMB, response.MaxBackupSizeMB)
assert.Equal(t, plan.MaxBackupsTotalSizeMB, response.MaxBackupsTotalSizeMB)
assert.True(t, response.IsRetryIfFailed)
assert.Equal(t, 3, response.MaxFailedTriesCount)
assert.NotNil(t, response.BackupInterval)
}
// A plan must always be returned for a freshly created database
// (fetching the plan auto-creates it when missing).
func Test_GetDatabasePlan_ForNewDatabase_PlanAlwaysReturned(t *testing.T) {
	// Arrange: a member-owned workspace with one new database.
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
	defer func() {
		databases.RemoveTestDatabase(database)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()

	// Act: fetch the plan over the public API.
	planURL := "/api/v1/backup-configs/database/" + database.ID.String() + "/plan"
	var plan plans.DatabasePlan
	test_utils.MakeGetRequestAndUnmarshal(
		t,
		router,
		planURL,
		"Bearer "+owner.Token,
		http.StatusOK,
		&plan,
	)

	// Assert: the plan belongs to this database and carries all limit fields.
	assert.Equal(t, database.ID, plan.DatabaseID)
	assert.NotNil(t, plan.MaxBackupSizeMB)
	assert.NotNil(t, plan.MaxBackupsTotalSizeMB)
	assert.NotEmpty(t, plan.MaxStoragePeriod)
}
// Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced checks
// that the save endpoint rejects configs exceeding the per-database plan
// (backup size, total backups size, storage period) and accepts a config
// that is inside every limit. Plan limits are tightened directly in the DB
// after the plan is auto-created via the API.
func Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced(t *testing.T) {
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
	defer func() {
		databases.RemoveTestDatabase(database)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	// Get plan via API (triggers auto-creation)
	var plan plans.DatabasePlan
	test_utils.MakeGetRequestAndUnmarshal(
		t,
		router,
		"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
		"Bearer "+owner.Token,
		http.StatusOK,
		&plan,
	)
	assert.Equal(t, database.ID, plan.DatabaseID)
	// Adjust plan limits directly in database to fixed restrictive values
	// (bypasses the API on purpose: there is no endpoint to shrink a plan).
	err := storage.GetDb().Model(&plans.DatabasePlan{}).
		Where("database_id = ?", database.ID).
		Updates(map[string]any{
			"max_backup_size_mb":        100,
			"max_backups_total_size_mb": 1000,
			"max_storage_period":        period.PeriodMonth,
		}).Error
	assert.NoError(t, err)
	// Test 1: Try to save backup config with exceeded backup size limit
	timeOfDay := "04:00"
	backupConfigExceededSize := BackupConfig{
		DatabaseID:          database.ID,
		IsBackupsEnabled:    true,
		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
		RetentionTimePeriod: period.PeriodWeek,
		BackupInterval: &intervals.Interval{
			Interval:  intervals.IntervalDaily,
			TimeOfDay: &timeOfDay,
		},
		SendNotificationsOn: []BackupNotificationType{
			NotificationBackupFailed,
		},
		IsRetryIfFailed:       true,
		MaxFailedTriesCount:   3,
		Encryption:            BackupEncryptionNone,
		MaxBackupSizeMB:       200, // Exceeds limit of 100
		MaxBackupsTotalSizeMB: 800,
	}
	respExceededSize := test_utils.MakePostRequest(
		t,
		router,
		"/api/v1/backup-configs/save",
		"Bearer "+owner.Token,
		backupConfigExceededSize,
		http.StatusBadRequest,
	)
	assert.Contains(t, string(respExceededSize.Body), "max backup size exceeds plan limit")
	// Test 2: Try to save backup config with exceeded total size limit
	backupConfigExceededTotal := BackupConfig{
		DatabaseID:          database.ID,
		IsBackupsEnabled:    true,
		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
		RetentionTimePeriod: period.PeriodWeek,
		BackupInterval: &intervals.Interval{
			Interval:  intervals.IntervalDaily,
			TimeOfDay: &timeOfDay,
		},
		SendNotificationsOn: []BackupNotificationType{
			NotificationBackupFailed,
		},
		IsRetryIfFailed:       true,
		MaxFailedTriesCount:   3,
		Encryption:            BackupEncryptionNone,
		MaxBackupSizeMB:       50,
		MaxBackupsTotalSizeMB: 2000, // Exceeds limit of 1000
	}
	respExceededTotal := test_utils.MakePostRequest(
		t,
		router,
		"/api/v1/backup-configs/save",
		"Bearer "+owner.Token,
		backupConfigExceededTotal,
		http.StatusBadRequest,
	)
	assert.Contains(t, string(respExceededTotal.Body), "max total backups size exceeds plan limit")
	// Test 3: Try to save backup config with exceeded storage period limit
	backupConfigExceededPeriod := BackupConfig{
		DatabaseID:          database.ID,
		IsBackupsEnabled:    true,
		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
		RetentionTimePeriod: period.PeriodYear, // Exceeds limit of Month
		BackupInterval: &intervals.Interval{
			Interval:  intervals.IntervalDaily,
			TimeOfDay: &timeOfDay,
		},
		SendNotificationsOn: []BackupNotificationType{
			NotificationBackupFailed,
		},
		IsRetryIfFailed:       true,
		MaxFailedTriesCount:   3,
		Encryption:            BackupEncryptionNone,
		MaxBackupSizeMB:       80,
		MaxBackupsTotalSizeMB: 800,
	}
	respExceededPeriod := test_utils.MakePostRequest(
		t,
		router,
		"/api/v1/backup-configs/save",
		"Bearer "+owner.Token,
		backupConfigExceededPeriod,
		http.StatusBadRequest,
	)
	assert.Contains(t, string(respExceededPeriod.Body), "storage period exceeds plan limit")
	// Test 4: Save backup config within all limits - should succeed
	backupConfigValid := BackupConfig{
		DatabaseID:          database.ID,
		IsBackupsEnabled:    true,
		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
		RetentionTimePeriod: period.PeriodWeek, // Within Month limit
		BackupInterval: &intervals.Interval{
			Interval:  intervals.IntervalDaily,
			TimeOfDay: &timeOfDay,
		},
		SendNotificationsOn: []BackupNotificationType{
			NotificationBackupFailed,
		},
		IsRetryIfFailed:       true,
		MaxFailedTriesCount:   3,
		Encryption:            BackupEncryptionNone,
		MaxBackupSizeMB:       80,  // Within 100 limit
		MaxBackupsTotalSizeMB: 800, // Within 1000 limit
	}
	var responseValid BackupConfig
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/backup-configs/save",
		"Bearer "+owner.Token,
		backupConfigValid,
		http.StatusOK,
		&responseValid,
	)
	// The saved config must echo back the requested limits unchanged.
	assert.Equal(t, database.ID, responseValid.DatabaseID)
	assert.Equal(t, int64(80), responseValid.MaxBackupSizeMB)
	assert.Equal(t, int64(800), responseValid.MaxBackupsTotalSizeMB)
	assert.Equal(t, period.PeriodWeek, responseValid.RetentionTimePeriod)
}
func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string

View File

@@ -6,7 +6,6 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/logger"
@@ -20,7 +19,6 @@ var (
storages.GetStorageService(),
notifiers.GetNotifierService(),
workspaces_services.GetWorkspaceService(),
plans.GetDatabasePlanService(),
nil,
}
)

View File

@@ -9,7 +9,6 @@ import (
"databasus-backend/internal/config"
"databasus-backend/internal/features/intervals"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/util/period"
)
@@ -42,11 +41,6 @@ type BackupConfig struct {
MaxFailedTriesCount int `json:"maxFailedTriesCount" gorm:"column:max_failed_tries_count;type:int;not null"`
Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
// MaxBackupSizeMB limits individual backup size. 0 = unlimited.
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
// MaxBackupsTotalSizeMB limits total size of all backups. 0 = unlimited.
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
}
func (h *BackupConfig) TableName() string {
@@ -86,12 +80,12 @@ func (b *BackupConfig) AfterFind(tx *gorm.DB) error {
return nil
}
func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
func (b *BackupConfig) Validate() error {
if b.BackupIntervalID == uuid.Nil && b.BackupInterval == nil {
return errors.New("backup interval is required")
}
if err := b.validateRetentionPolicy(plan); err != nil {
if err := b.validateRetentionPolicy(); err != nil {
return err
}
@@ -110,67 +104,38 @@ func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
}
}
if b.MaxBackupSizeMB < 0 {
return errors.New("max backup size must be non-negative")
}
if b.MaxBackupsTotalSizeMB < 0 {
return errors.New("max backups total size must be non-negative")
}
if plan.MaxBackupSizeMB > 0 {
if b.MaxBackupSizeMB == 0 || b.MaxBackupSizeMB > plan.MaxBackupSizeMB {
return errors.New("max backup size exceeds plan limit")
}
}
if plan.MaxBackupsTotalSizeMB > 0 {
if b.MaxBackupsTotalSizeMB == 0 ||
b.MaxBackupsTotalSizeMB > plan.MaxBackupsTotalSizeMB {
return errors.New("max total backups size exceeds plan limit")
}
}
return nil
}
// Copy returns a deep-enough clone of the config bound to newDatabaseID:
// the backup interval is copied (with a nil interval ID so a new row is
// created on save), while other fields are carried over by value.
//
// NOTE(review): rendered diff residue — the removed field set (which still
// carried MaxBackupSizeMB / MaxBackupsTotalSizeMB) and the added field set
// are both shown below; only the second exists after this commit.
func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
	return &BackupConfig{
		// removed (pre-change) field set:
		DatabaseID: newDatabaseID,
		IsBackupsEnabled: b.IsBackupsEnabled,
		RetentionPolicyType: b.RetentionPolicyType,
		RetentionTimePeriod: b.RetentionTimePeriod,
		RetentionCount: b.RetentionCount,
		RetentionGfsHours: b.RetentionGfsHours,
		RetentionGfsDays: b.RetentionGfsDays,
		RetentionGfsWeeks: b.RetentionGfsWeeks,
		RetentionGfsMonths: b.RetentionGfsMonths,
		RetentionGfsYears: b.RetentionGfsYears,
		BackupIntervalID: uuid.Nil,
		BackupInterval: b.BackupInterval.Copy(),
		StorageID: b.StorageID,
		SendNotificationsOn: b.SendNotificationsOn,
		IsRetryIfFailed: b.IsRetryIfFailed,
		MaxFailedTriesCount: b.MaxFailedTriesCount,
		Encryption: b.Encryption,
		MaxBackupSizeMB: b.MaxBackupSizeMB,
		MaxBackupsTotalSizeMB: b.MaxBackupsTotalSizeMB,
		// added (post-change) field set:
		DatabaseID:          newDatabaseID,
		IsBackupsEnabled:    b.IsBackupsEnabled,
		RetentionPolicyType: b.RetentionPolicyType,
		RetentionTimePeriod: b.RetentionTimePeriod,
		RetentionCount:      b.RetentionCount,
		RetentionGfsHours:   b.RetentionGfsHours,
		RetentionGfsDays:    b.RetentionGfsDays,
		RetentionGfsWeeks:   b.RetentionGfsWeeks,
		RetentionGfsMonths:  b.RetentionGfsMonths,
		RetentionGfsYears:   b.RetentionGfsYears,
		BackupIntervalID:    uuid.Nil,
		BackupInterval:      b.BackupInterval.Copy(),
		StorageID:           b.StorageID,
		SendNotificationsOn: b.SendNotificationsOn,
		IsRetryIfFailed:     b.IsRetryIfFailed,
		MaxFailedTriesCount: b.MaxFailedTriesCount,
		Encryption:          b.Encryption,
	}
}
func (b *BackupConfig) validateRetentionPolicy(plan *plans.DatabasePlan) error {
func (b *BackupConfig) validateRetentionPolicy() error {
switch b.RetentionPolicyType {
case RetentionPolicyTypeTimePeriod, "":
if b.RetentionTimePeriod == "" {
return errors.New("retention time period is required")
}
if plan.MaxStoragePeriod != period.PeriodForever {
if b.RetentionTimePeriod.CompareTo(plan.MaxStoragePeriod) > 0 {
return errors.New("storage period exceeds plan limit")
}
}
case RetentionPolicyTypeCount:
if b.RetentionCount <= 0 {
return errors.New("retention count must be greater than 0")

View File

@@ -6,248 +6,34 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/intervals"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/util/period"
)
func Test_Validate_WhenRetentionTimePeriodIsWeekAndPlanAllowsMonth_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodWeek
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenRetentionTimePeriodIsYearAndPlanAllowsMonth_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodYear
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
err := config.Validate(plan)
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenRetentionTimePeriodIsForeverAndPlanAllowsForever_ValidationPasses(
t *testing.T,
) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodForever
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodForever
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenRetentionTimePeriodIsForeverAndPlanAllowsYear_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodForever
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodYear
err := config.Validate(plan)
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenRetentionTimePeriodEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodMonth
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenBackupSize100MBAndPlanAllows500MB_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 100
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 500
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenBackupSize500MBAndPlanAllows100MB_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 500
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 100
err := config.Validate(plan)
assert.EqualError(t, err, "max backup size exceeds plan limit")
}
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 0
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanHas500MBLimit_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 500
err := config.Validate(plan)
assert.EqualError(t, err, "max backup size exceeds plan limit")
}
func Test_Validate_WhenBackupSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 500
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 500
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenTotalSize1GBAndPlanAllows5GB_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 1000
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 5000
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenTotalSize5GBAndPlanAllows1GB_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 5000
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 1000
err := config.Validate(plan)
assert.EqualError(t, err, "max total backups size exceeds plan limit")
}
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 0
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanHas1GBLimit_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 1000
err := config.Validate(plan)
assert.EqualError(t, err, "max total backups size exceeds plan limit")
}
func Test_Validate_WhenTotalSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 5000
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 5000
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenAllLimitsAreUnlimitedInPlan_AnyConfigurationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodForever
config.MaxBackupSizeMB = 0
config.MaxBackupsTotalSizeMB = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenMultipleLimitsExceeded_ValidationFailsWithFirstError(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodYear
config.MaxBackupSizeMB = 500
config.MaxBackupsTotalSizeMB = 5000
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
plan.MaxBackupSizeMB = 100
plan.MaxBackupsTotalSizeMB = 1000
err := config.Validate(plan)
assert.Error(t, err)
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenConfigHasInvalidIntervalButPlanIsValid_ValidationFailsOnInterval(
t *testing.T,
) {
func Test_Validate_WhenIntervalIsMissing_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.BackupIntervalID = uuid.Nil
config.BackupInterval = nil
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "backup interval is required")
}
func Test_Validate_WhenIntervalIsMissing_ValidationFailsRegardlessOfPlan(t *testing.T) {
config := createValidBackupConfig()
config.BackupIntervalID = uuid.Nil
config.BackupInterval = nil
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "backup interval is required")
}
func Test_Validate_WhenRetryEnabledButMaxTriesIsZero_ValidationFailsRegardlessOfPlan(t *testing.T) {
func Test_Validate_WhenRetryEnabledButMaxTriesIsZero_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.IsRetryIfFailed = true
config.MaxFailedTriesCount = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "max failed tries count must be greater than 0")
}
func Test_Validate_WhenEncryptionIsInvalid_ValidationFailsRegardlessOfPlan(t *testing.T) {
func Test_Validate_WhenEncryptionIsInvalid_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.Encryption = "INVALID"
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "encryption must be NONE or ENCRYPTED")
}
@@ -255,125 +41,16 @@ func Test_Validate_WhenRetentionTimePeriodIsEmpty_ValidationFails(t *testing.T)
config := createValidBackupConfig()
config.RetentionTimePeriod = ""
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "retention time period is required")
}
// A negative per-backup size limit is rejected regardless of the plan.
func Test_Validate_WhenMaxBackupSizeIsNegative_ValidationFails(t *testing.T) {
	cfg := createValidBackupConfig()
	cfg.MaxBackupSizeMB = -100

	err := cfg.Validate(createUnlimitedPlan())

	assert.EqualError(t, err, "max backup size must be non-negative")
}
// A negative total-backups size limit is rejected regardless of the plan.
func Test_Validate_WhenMaxTotalSizeIsNegative_ValidationFails(t *testing.T) {
	cfg := createValidBackupConfig()
	cfg.MaxBackupsTotalSizeMB = -1000

	err := cfg.Validate(createUnlimitedPlan())

	assert.EqualError(t, err, "max backups total size must be non-negative")
}
// Test_Validate_WhenPlanLimitsAreAtBoundary_ValidationWorks is a table test
// exercising exact-boundary behavior: values at or under the plan limit pass,
// values one step over any single limit fail.
func Test_Validate_WhenPlanLimitsAreAtBoundary_ValidationWorks(t *testing.T) {
	tests := []struct {
		name          string
		configPeriod  period.TimePeriod
		planPeriod    period.TimePeriod
		configSize    int64
		planSize      int64
		configTotal   int64
		planTotal     int64
		shouldSucceed bool
	}{
		{
			name:          "all values just under limit",
			configPeriod:  period.PeriodWeek,
			planPeriod:    period.PeriodMonth,
			configSize:    99,
			planSize:      100,
			configTotal:   999,
			planTotal:     1000,
			shouldSucceed: true,
		},
		{
			// Limits are inclusive: equality must pass.
			name:          "all values equal to limit",
			configPeriod:  period.PeriodMonth,
			planPeriod:    period.PeriodMonth,
			configSize:    100,
			planSize:      100,
			configTotal:   1000,
			planTotal:     1000,
			shouldSucceed: true,
		},
		{
			name:          "period just over limit",
			configPeriod:  period.Period3Month,
			planPeriod:    period.PeriodMonth,
			configSize:    100,
			planSize:      100,
			configTotal:   1000,
			planTotal:     1000,
			shouldSucceed: false,
		},
		{
			name:          "size just over limit",
			configPeriod:  period.PeriodMonth,
			planPeriod:    period.PeriodMonth,
			configSize:    101,
			planSize:      100,
			configTotal:   1000,
			planTotal:     1000,
			shouldSucceed: false,
		},
		{
			name:          "total size just over limit",
			configPeriod:  period.PeriodMonth,
			planPeriod:    period.PeriodMonth,
			configSize:    100,
			planSize:      100,
			configTotal:   1001,
			planTotal:     1000,
			shouldSucceed: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			config := createValidBackupConfig()
			config.RetentionTimePeriod = tt.configPeriod
			config.MaxBackupSizeMB = tt.configSize
			config.MaxBackupsTotalSizeMB = tt.configTotal
			plan := createUnlimitedPlan()
			plan.MaxStoragePeriod = tt.planPeriod
			plan.MaxBackupSizeMB = tt.planSize
			plan.MaxBackupsTotalSizeMB = tt.planTotal
			err := config.Validate(plan)
			if tt.shouldSucceed {
				assert.NoError(t, err)
			} else {
				// Failure cases only assert that *some* limit error is
				// returned; specific messages are covered elsewhere.
				assert.Error(t, err)
			}
		})
	}
}
func Test_Validate_WhenPolicyTypeIsCount_RequiresPositiveCount(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = RetentionPolicyTypeCount
config.RetentionCount = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "retention count must be greater than 0")
}
@@ -382,9 +59,7 @@ func Test_Validate_WhenPolicyTypeIsCount_WithPositiveCount_ValidationPasses(t *t
config.RetentionPolicyType = RetentionPolicyTypeCount
config.RetentionCount = 10
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.NoError(t, err)
}
@@ -396,9 +71,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_RequiresAtLeastOneField(t *testing.T) {
config.RetentionGfsMonths = 0
config.RetentionGfsYears = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "at least one GFS retention field must be greater than 0")
}
@@ -407,9 +80,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_WithOnlyHours_ValidationPasses(t *testing
config.RetentionPolicyType = RetentionPolicyTypeGFS
config.RetentionGfsHours = 24
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.NoError(t, err)
}
@@ -418,9 +89,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_WithOnlyDays_ValidationPasses(t *testing.
config.RetentionPolicyType = RetentionPolicyTypeGFS
config.RetentionGfsDays = 7
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.NoError(t, err)
}
@@ -433,9 +102,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_WithAllFields_ValidationPasses(t *testing
config.RetentionGfsMonths = 12
config.RetentionGfsYears = 3
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.NoError(t, err)
}
@@ -443,35 +110,59 @@ func Test_Validate_WhenPolicyTypeIsInvalid_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = "INVALID"
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "invalid retention policy type")
}
// In cloud mode, a config without encryption must be rejected.
func Test_Validate_WhenCloudAndEncryptionIsNotEncrypted_ValidationFails(t *testing.T) {
	enableCloud(t)

	cfg := createValidBackupConfig()
	cfg.Encryption = BackupEncryptionNone

	assert.EqualError(t, cfg.Validate(), "encryption is mandatory for cloud storage")
}
// In cloud mode, an encrypted config passes validation.
func Test_Validate_WhenCloudAndEncryptionIsEncrypted_ValidationPasses(t *testing.T) {
	enableCloud(t)

	cfg := createValidBackupConfig()
	cfg.Encryption = BackupEncryptionEncrypted

	assert.NoError(t, cfg.Validate())
}
// Outside cloud mode (self-hosted), unencrypted backups remain allowed.
func Test_Validate_WhenNotCloudAndEncryptionIsNotEncrypted_ValidationPasses(t *testing.T) {
	cfg := createValidBackupConfig()
	cfg.Encryption = BackupEncryptionNone

	assert.NoError(t, cfg.Validate())
}
// enableCloud switches the process-wide environment into cloud mode for the
// duration of the calling test and restores the previous state on cleanup.
//
// Fix: the original cleanup unconditionally reset IsCloud to false, which
// would clobber a pre-existing true value (e.g. when the suite itself runs in
// cloud mode). Capturing and restoring the prior value is correct in both
// cases and behaves identically when the prior value was false.
func enableCloud(t *testing.T) {
	t.Helper()
	env := config.GetEnv()
	previous := env.IsCloud
	env.IsCloud = true
	t.Cleanup(func() {
		env.IsCloud = previous
	})
}
func createValidBackupConfig() *BackupConfig {
intervalID := uuid.New()
return &BackupConfig{
DatabaseID: uuid.New(),
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodMonth,
BackupIntervalID: intervalID,
BackupInterval: &intervals.Interval{ID: intervalID},
SendNotificationsOn: []BackupNotificationType{},
IsRetryIfFailed: false,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 100,
MaxBackupsTotalSizeMB: 1000,
}
}
func createUnlimitedPlan() *plans.DatabasePlan {
return &plans.DatabasePlan{
DatabaseID: uuid.New(),
MaxBackupSizeMB: 0,
MaxBackupsTotalSizeMB: 0,
MaxStoragePeriod: period.PeriodForever,
return &BackupConfig{
DatabaseID: uuid.New(),
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodMonth,
BackupIntervalID: intervalID,
BackupInterval: &intervals.Interval{ID: intervalID},
SendNotificationsOn: []BackupNotificationType{},
IsRetryIfFailed: false,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
}
}

View File

@@ -8,10 +8,10 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
users_models "databasus-backend/internal/features/users/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/period"
)
type BackupConfigService struct {
@@ -20,7 +20,6 @@ type BackupConfigService struct {
storageService *storages.StorageService
notifierService *notifiers.NotifierService
workspaceService *workspaces_services.WorkspaceService
databasePlanService *plans.DatabasePlanService
dbStorageChangeListener BackupConfigStorageChangeListener
}
@@ -46,12 +45,7 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
user *users_models.User,
backupConfig *BackupConfig,
) (*BackupConfig, error) {
plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
if err != nil {
return nil, err
}
if err := backupConfig.Validate(plan); err != nil {
if err := backupConfig.Validate(); err != nil {
return nil, err
}
@@ -88,12 +82,7 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
func (s *BackupConfigService) SaveBackupConfig(
backupConfig *BackupConfig,
) (*BackupConfig, error) {
plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
if err != nil {
return nil, err
}
if err := backupConfig.Validate(plan); err != nil {
if err := backupConfig.Validate(); err != nil {
return nil, err
}
@@ -131,18 +120,6 @@ func (s *BackupConfigService) GetBackupConfigByDbIdWithAuth(
return s.GetBackupConfigByDbId(databaseID)
}
func (s *BackupConfigService) GetDatabasePlan(
user *users_models.User,
databaseID uuid.UUID,
) (*plans.DatabasePlan, error) {
_, err := s.databaseService.GetDatabase(user, databaseID)
if err != nil {
return nil, err
}
return s.databasePlanService.GetDatabasePlan(databaseID)
}
func (s *BackupConfigService) GetBackupConfigByDbId(
databaseID uuid.UUID,
) (*BackupConfig, error) {
@@ -322,20 +299,13 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace(
func (s *BackupConfigService) initializeDefaultConfig(
databaseID uuid.UUID,
) error {
plan, err := s.databasePlanService.GetDatabasePlan(databaseID)
if err != nil {
return err
}
timeOfDay := "04:00"
_, err = s.backupConfigRepository.Save(&BackupConfig{
DatabaseID: databaseID,
IsBackupsEnabled: false,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: plan.MaxStoragePeriod,
MaxBackupSizeMB: plan.MaxBackupSizeMB,
MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB,
_, err := s.backupConfigRepository.Save(&BackupConfig{
DatabaseID: databaseID,
IsBackupsEnabled: false,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.Period3Month,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,

View File

@@ -0,0 +1,305 @@
package billing
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
users_middleware "databasus-backend/internal/features/users/middleware"
"databasus-backend/internal/util/logger"
)
// BillingController exposes the authenticated HTTP API for storage
// subscriptions: creation, resizing, portal access, event history and
// invoices. All handlers expect AuthMiddleware to have populated the user.
type BillingController struct {
	billingService *BillingService
}

// RegisterRoutes mounts all billing endpoints under /billing.
// NOTE(review): the static routes (change-storage, portal, events,
// invoices) share the /subscription prefix with the parameterized
// /subscription/:database_id — confirm gin's router resolves these
// without conflict for this gin version.
func (c *BillingController) RegisterRoutes(router *gin.RouterGroup) {
	billing := router.Group("/billing")
	billing.POST("/subscription", c.CreateSubscription)
	billing.POST("/subscription/change-storage", c.ChangeSubscriptionStorage)
	billing.POST("/subscription/portal/:subscription_id", c.GetPortalSession)
	billing.GET("/subscription/events/:subscription_id", c.GetSubscriptionEvents)
	billing.GET("/subscription/invoices/:subscription_id", c.GetInvoices)
	billing.GET("/subscription/:database_id", c.GetSubscription)
}
// CreateSubscription
// @Summary Create a new subscription
// @Description Create a billing subscription for the specified database with the given storage
// @Tags billing
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body CreateSubscriptionRequest true "Subscription creation data"
// @Success 200 {object} CreateSubscriptionResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /billing/subscription [post]
func (c *BillingController) CreateSubscription(ctx *gin.Context) {
	user, ok := users_middleware.GetUserFromContext(ctx)
	if !ok {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}

	var req CreateSubscriptionRequest
	if err := ctx.ShouldBindJSON(&req); err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
		return
	}

	log := logger.GetLogger().With(
		"request_id", uuid.New(),
		"database_id", req.DatabaseID,
		"user_id", user.ID,
	)

	// The service creates a Paddle transaction; the frontend opens the
	// checkout for the returned transaction id.
	txID, err := c.billingService.CreateSubscription(log, user, req.DatabaseID, req.StorageGB)
	if err != nil {
		log.Error("Failed to create subscription", "error", err)
		ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create subscription"})
		return
	}

	ctx.JSON(http.StatusOK, CreateSubscriptionResponse{PaddleTransactionID: txID})
}
// ChangeSubscriptionStorage
// @Summary Change subscription storage
// @Description Update the storage allocation for an existing subscription
// @Tags billing
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body ChangeStorageRequest true "New storage configuration"
// @Success 200 {object} ChangeStorageResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /billing/subscription/change-storage [post]
func (c *BillingController) ChangeSubscriptionStorage(ctx *gin.Context) {
	user, ok := users_middleware.GetUserFromContext(ctx)
	if !ok {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}

	var req ChangeStorageRequest
	if err := ctx.ShouldBindJSON(&req); err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
		return
	}

	log := logger.GetLogger().With(
		"request_id", uuid.New(),
		"database_id", req.DatabaseID,
		"user_id", user.ID,
	)

	// Upgrades apply immediately, downgrades from the next cycle; the
	// result carries which mode was used plus current/pending sizes.
	result, err := c.billingService.ChangeSubscriptionStorage(log, user, req.DatabaseID, req.StorageGB)
	if err != nil {
		log.Error("Failed to change subscription storage", "error", err)
		ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to change subscription storage"})
		return
	}

	ctx.JSON(http.StatusOK, ChangeStorageResponse{
		ApplyMode: result.ApplyMode,
		CurrentGB: result.CurrentGB,
		PendingGB: result.PendingGB,
	})
}
// GetPortalSession
// @Summary Get billing portal session
// @Description Generate a portal session URL for managing the subscription
// @Tags billing
// @Produce json
// @Security BearerAuth
// @Param subscription_id path string true "Subscription ID"
// @Success 200 {object} GetPortalSessionResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /billing/subscription/portal/{subscription_id} [post]
func (c *BillingController) GetPortalSession(ctx *gin.Context) {
	user, ok := users_middleware.GetUserFromContext(ctx)
	if !ok {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}

	rawID := ctx.Param("subscription_id")
	if rawID == "" {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Subscription ID is required"})
		return
	}

	// Bug fix: previously this used uuid.MustParse, which panics on a
	// malformed client-supplied ID. Parse + 400 matches the sibling
	// handlers (GetSubscriptionEvents, GetInvoices, GetSubscription).
	subscriptionID, err := uuid.Parse(rawID)
	if err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid subscription ID"})
		return
	}

	log := logger.GetLogger().With(
		"request_id", uuid.New(),
		"subscription_id", subscriptionID,
		"user_id", user.ID,
	)

	url, err := c.billingService.GetPortalURL(log, user, subscriptionID)
	if err != nil {
		log.Error("Failed to get portal session", "error", err)
		ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get portal session"})
		return
	}

	ctx.JSON(http.StatusOK, GetPortalSessionResponse{PortalURL: url})
}
// GetSubscriptionEvents
// @Summary Get subscription events
// @Description Retrieve the event history for a subscription
// @Tags billing
// @Produce json
// @Security BearerAuth
// @Param subscription_id path string true "Subscription ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Success 200 {object} GetSubscriptionEventsResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /billing/subscription/events/{subscription_id} [get]
func (c *BillingController) GetSubscriptionEvents(ctx *gin.Context) {
	user, ok := users_middleware.GetUserFromContext(ctx)
	if !ok {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}

	subscriptionID, err := uuid.Parse(ctx.Param("subscription_id"))
	if err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid subscription ID"})
		return
	}

	var pagination PaginatedRequest
	if err := ctx.ShouldBindQuery(&pagination); err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
		return
	}

	log := logger.GetLogger().With(
		"request_id", uuid.New(),
		"subscription_id", subscriptionID,
		"user_id", user.ID,
	)

	response, err := c.billingService.GetSubscriptionEvents(log, user, subscriptionID, pagination.Limit, pagination.Offset)
	if err != nil {
		log.Error("Failed to get subscription events", "error", err)
		ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get subscription events"})
		return
	}

	ctx.JSON(http.StatusOK, response)
}
// GetInvoices
// @Summary Get subscription invoices
// @Description Retrieve all invoices for a subscription
// @Tags billing
// @Produce json
// @Security BearerAuth
// @Param subscription_id path string true "Subscription ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Success 200 {object} GetInvoicesResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /billing/subscription/invoices/{subscription_id} [get]
func (c *BillingController) GetInvoices(ctx *gin.Context) {
	user, ok := users_middleware.GetUserFromContext(ctx)
	if !ok {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}

	subscriptionID, err := uuid.Parse(ctx.Param("subscription_id"))
	if err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid subscription ID"})
		return
	}

	var pagination PaginatedRequest
	if err := ctx.ShouldBindQuery(&pagination); err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
		return
	}

	log := logger.GetLogger().With(
		"request_id", uuid.New(),
		"subscription_id", subscriptionID,
		"user_id", user.ID,
	)

	response, err := c.billingService.GetSubscriptionInvoices(log, user, subscriptionID, pagination.Limit, pagination.Offset)
	if err != nil {
		log.Error("Failed to get invoices", "error", err)
		ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get invoices"})
		return
	}

	ctx.JSON(http.StatusOK, response)
}
// GetSubscription
// @Summary Get subscription by database
// @Description Retrieve the subscription associated with a specific database
// @Tags billing
// @Produce json
// @Security BearerAuth
// @Param database_id path string true "Database ID"
// @Success 200 {object} billing_models.Subscription
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /billing/subscription/{database_id} [get]
func (c *BillingController) GetSubscription(ctx *gin.Context) {
	user, ok := users_middleware.GetUserFromContext(ctx)
	if !ok {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}

	databaseID, err := uuid.Parse(ctx.Param("database_id"))
	if err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid database ID"})
		return
	}

	log := logger.GetLogger().With(
		"request_id", uuid.New(),
		"database_id", databaseID,
		"user_id", user.ID,
	)

	subscription, err := c.billingService.GetSubscriptionByDatabaseID(log, user, databaseID)
	switch {
	case err == nil:
		ctx.JSON(http.StatusOK, subscription)
	case errors.Is(err, ErrSubscriptionNotFound):
		// A database without a subscription is an expected state, not an error.
		ctx.JSON(http.StatusNotFound, gin.H{"error": "Subscription not found"})
	default:
		log.Error("failed to get subscription", "error", err)
		ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get subscription"})
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,49 @@
package billing
import (
"sync"
"sync/atomic"
billing_repositories "databasus-backend/internal/features/billing/repositories"
"databasus-backend/internal/features/databases"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/logger"
)
var (
	// billingService is the package singleton. Fields are positional, so
	// this literal must stay in sync with the BillingService struct order:
	// subscription / event / invoice repositories, provider, workspace
	// service, database service, once, bool.
	billingService = &BillingService{
		&billing_repositories.SubscriptionRepository{},
		&billing_repositories.SubscriptionEventRepository{},
		&billing_repositories.InvoiceRepository{},
		nil, // billing provider will be set later to avoid circular dependency
		workspaces_services.GetWorkspaceService(),
		*databases.GetDatabaseService(),
		sync.Once{},
		atomic.Bool{},
	}
	billingController = &BillingController{billingService}
	setupOnce         sync.Once
	isSetup           atomic.Bool
)

// GetBillingService returns the package-level billing service singleton.
func GetBillingService() *BillingService {
	return billingService
}

// GetBillingController returns the package-level billing controller singleton.
func GetBillingController() *BillingController {
	return billingController
}

// SetupDependencies wires the billing service into the database service as
// a creation listener. Only the first call performs the wiring; later calls
// log a warning.
// NOTE(review): isSetup is read before setupOnce.Do, so two concurrent
// first calls may both see false and neither warns — harmless here, but
// worth confirming this is only called from startup code.
func SetupDependencies() {
	wasAlreadySetup := isSetup.Load()
	setupOnce.Do(func() {
		databases.GetDatabaseService().AddDbCreationListener(billingService)
		isSetup.Store(true)
	})
	if wasAlreadySetup {
		logger.GetLogger().Warn("billing.SetupDependencies called multiple times, ignoring subsequent call")
	}
}

View File

@@ -0,0 +1,67 @@
package billing
import (
"github.com/google/uuid"
billing_models "databasus-backend/internal/features/billing/models"
)
// CreateSubscriptionRequest is the body of POST /billing/subscription.
type CreateSubscriptionRequest struct {
	DatabaseID uuid.UUID `json:"databaseId" validate:"required"`
	StorageGB  int       `json:"storageGb" validate:"required,min=1"`
}

// CreateSubscriptionResponse returns the Paddle transaction the frontend
// should open for checkout.
type CreateSubscriptionResponse struct {
	PaddleTransactionID string `json:"paddleTransactionId"`
}

// ChangeStorageApplyMode describes when a storage change takes effect.
type ChangeStorageApplyMode string

const (
	// ChangeStorageApplyImmediate — change (upgrade) applied right away.
	ChangeStorageApplyImmediate ChangeStorageApplyMode = "immediate"
	// ChangeStorageApplyNextCycle — change (downgrade) scheduled for the
	// next billing cycle.
	ChangeStorageApplyNextCycle ChangeStorageApplyMode = "next_cycle"
)

// ChangeStorageRequest is the body of POST /billing/subscription/change-storage.
type ChangeStorageRequest struct {
	DatabaseID uuid.UUID `json:"databaseId" validate:"required"`
	StorageGB  int       `json:"storageGb" validate:"required,min=1"`
}

// ChangeStorageResponse reports how the change was applied and the
// current/pending sizes.
type ChangeStorageResponse struct {
	ApplyMode ChangeStorageApplyMode `json:"applyMode"`
	CurrentGB int                    `json:"currentGb"`
	PendingGB *int                   `json:"pendingGb,omitempty"`
}

// PortalResponse wraps a billing portal URL.
// NOTE(review): looks like a duplicate of GetPortalSessionResponse —
// confirm whether both are actually used.
type PortalResponse struct {
	URL string `json:"url"`
}

// ChangeStorageResult is the service-layer result of a storage change.
type ChangeStorageResult struct {
	ApplyMode ChangeStorageApplyMode
	CurrentGB int
	PendingGB *int
}

// GetPortalSessionResponse returns the provider's customer-portal URL.
type GetPortalSessionResponse struct {
	PortalURL string `json:"url"`
}

// PaginatedRequest carries limit/offset query parameters.
// NOTE(review): no defaults are applied here despite the swagger
// default(100)/default(0) annotations — presumably the service layer
// normalizes limit==0; verify.
type PaginatedRequest struct {
	Limit  int `form:"limit" json:"limit"`
	Offset int `form:"offset" json:"offset"`
}

// GetSubscriptionEventsResponse is a paginated page of subscription events.
type GetSubscriptionEventsResponse struct {
	Events []*billing_models.SubscriptionEvent `json:"events"`
	Total  int64                               `json:"total"`
	Limit  int                                 `json:"limit"`
	Offset int                                 `json:"offset"`
}

// GetInvoicesResponse is a paginated page of invoices.
type GetInvoicesResponse struct {
	Invoices []*billing_models.Invoice `json:"invoices"`
	Total    int64                     `json:"total"`
	Limit    int                       `json:"limit"`
	Offset   int                       `json:"offset"`
}

View File

@@ -0,0 +1,15 @@
package billing
import "errors"
// Sentinel errors for the billing feature. Callers compare with errors.Is;
// controllers map them to HTTP status codes.
var (
	ErrInvalidStorage       = errors.New("storage must be between 20 and 10000 GB")
	ErrAlreadySubscribed    = errors.New("database already has an active subscription")
	ErrExceedsUsage         = errors.New("cannot downgrade below current storage usage")
	ErrNoChange             = errors.New("requested storage is the same as current")
	ErrDuplicate            = errors.New("duplicate event already processed")
	ErrProviderUnavailable  = errors.New("payment provider unavailable")
	ErrNoActiveSubscription = errors.New("no active subscription for this database")
	ErrAccessDenied         = errors.New("user does not have access to this database")
	ErrSubscriptionNotFound = errors.New("subscription not found")
)

View File

@@ -0,0 +1,24 @@
package billing_models
import (
"time"
"github.com/google/uuid"
)
// Invoice is a GORM model for one billing-period charge of a subscription,
// mirrored from the payment provider via webhooks.
type Invoice struct {
	ID             uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
	SubscriptionID uuid.UUID `json:"subscriptionId" gorm:"column:subscription_id;type:uuid;not null"`
	// ProviderInvoiceID is the provider-side identifier (e.g. a Paddle
	// transaction id) used for idempotent webhook processing.
	ProviderInvoiceID string        `json:"providerInvoiceId" gorm:"column:provider_invoice_id;type:text;not null"`
	AmountCents       int64         `json:"amountCents" gorm:"column:amount_cents;type:bigint;not null"`
	StorageGB         int           `json:"storageGb" gorm:"column:storage_gb;type:int;not null"`
	PeriodStart       time.Time     `json:"periodStart" gorm:"column:period_start;type:timestamptz;not null"`
	PeriodEnd         time.Time     `json:"periodEnd" gorm:"column:period_end;type:timestamptz;not null"`
	Status            InvoiceStatus `json:"status" gorm:"column:status;type:text;not null"`
	PaidAt            *time.Time    `json:"paidAt,omitempty" gorm:"column:paid_at;type:timestamptz"`
	CreatedAt         time.Time     `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"`
}

// TableName sets the GORM table name.
func (Invoice) TableName() string {
	return "invoices"
}

View File

@@ -0,0 +1,11 @@
package billing_models
// InvoiceStatus is the lifecycle state of an Invoice.
type InvoiceStatus string

const (
	InvoiceStatusPending  InvoiceStatus = "pending"
	InvoiceStatusPaid     InvoiceStatus = "paid"
	InvoiceStatusFailed   InvoiceStatus = "failed"
	InvoiceStatusRefunded InvoiceStatus = "refunded"
	InvoiceStatusDisputed InvoiceStatus = "disputed"
)

View File

@@ -0,0 +1,72 @@
package billing_models
import (
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
)
// Subscription is a GORM model of a per-database storage subscription.
// Provider* fields are nil until the subscription is linked to a payment
// provider (e.g. during trial).
type Subscription struct {
	ID         uuid.UUID          `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
	DatabaseID uuid.UUID          `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
	Status     SubscriptionStatus `json:"status" gorm:"column:status;type:text;not null"`
	StorageGB  int                `json:"storageGb" gorm:"column:storage_gb;type:int;not null"`
	// PendingStorageGB holds a downgrade scheduled for the next billing cycle.
	PendingStorageGB   *int       `json:"pendingStorageGb,omitempty" gorm:"column:pending_storage_gb;type:int"`
	CurrentPeriodStart time.Time  `json:"currentPeriodStart" gorm:"column:current_period_start;type:timestamptz;not null"`
	CurrentPeriodEnd   time.Time  `json:"currentPeriodEnd" gorm:"column:current_period_end;type:timestamptz;not null"`
	CanceledAt         *time.Time `json:"canceledAt,omitempty" gorm:"column:canceled_at;type:timestamptz"`
	// DataRetentionGracePeriodUntil — presumably when backup data becomes
	// eligible for deletion after cancellation; confirm against the
	// deletion job.
	DataRetentionGracePeriodUntil *time.Time `json:"dataRetentionGracePeriodUntil,omitempty" gorm:"column:data_retention_grace_period_until;type:timestamptz"`
	ProviderName                  *string    `json:"providerName,omitempty" gorm:"column:provider_name;type:text"`
	ProviderSubID                 *string    `json:"providerSubId,omitempty" gorm:"column:provider_sub_id;type:text"`
	ProviderCustomerID            *string    `json:"providerCustomerId,omitempty" gorm:"column:provider_customer_id;type:text"`
	CreatedAt                     time.Time  `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"`
	UpdatedAt                     time.Time  `json:"updatedAt" gorm:"column:updated_at;type:timestamptz;not null"`
}

// TableName sets the GORM table name.
func (Subscription) TableName() string {
	return "subscriptions"
}

// PriceCents returns the subscription's price in cents: storage size times
// the configured per-GB price (PRICE_PER_GB_CENTS).
func (s *Subscription) PriceCents() int64 {
	return int64(s.StorageGB) * config.GetEnv().PricePerGBCents
}
// CanCreateNewBackups reports whether new backups may be created for this
// subscription, by the scheduler or manually by the user. In the grace
// period the user can still download, delete and restore backups, but
// cannot create new ones.
func (s *Subscription) CanCreateNewBackups() bool {
	switch s.Status {
	case StatusExpired:
		return false
	case StatusActive, StatusPastDue:
		return true
	case StatusTrial, StatusCanceled:
		// Allowed only while the current (trial or already-paid) period
		// is still running.
		return time.Now().Before(s.CurrentPeriodEnd)
	default:
		panic("unknown subscription status")
	}
}
// GetBackupsStorageGB returns the storage quota in GB currently available
// for this subscription's backups; 0 means no storage at all.
func (s *Subscription) GetBackupsStorageGB() int {
	if s.Status == StatusExpired {
		return 0
	}
	if s.Status == StatusTrial {
		// Trial quota is only valid while the trial period is running.
		if time.Now().Before(s.CurrentPeriodEnd) {
			return s.StorageGB
		}
		return 0
	}
	if s.Status == StatusActive || s.Status == StatusPastDue || s.Status == StatusCanceled {
		return s.StorageGB
	}
	panic("unknown subscription status")
}

View File

@@ -0,0 +1,25 @@
package billing_models
import (
"time"
"github.com/google/uuid"
)
// SubscriptionEvent is a GORM model of one audit-trail entry for a
// subscription (creation, up/downgrade, status transitions, refunds…).
// Old*/New* fields are populated only when relevant for the event type.
type SubscriptionEvent struct {
	ID             uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
	SubscriptionID uuid.UUID `json:"subscriptionId" gorm:"column:subscription_id;type:uuid;not null"`
	// ProviderEventID is the provider-side event id; nil for events
	// generated internally rather than from a webhook.
	ProviderEventID *string               `json:"providerEventId,omitempty" gorm:"column:provider_event_id;type:text"`
	Type            SubscriptionEventType `json:"type" gorm:"column:type;type:text;not null"`
	OldStorageGB    *int                  `json:"oldStorageGb,omitempty" gorm:"column:old_storage_gb;type:int"`
	NewStorageGB    *int                  `json:"newStorageGb,omitempty" gorm:"column:new_storage_gb;type:int"`
	OldStatus       *SubscriptionStatus   `json:"oldStatus,omitempty" gorm:"column:old_status;type:text"`
	NewStatus       *SubscriptionStatus   `json:"newStatus,omitempty" gorm:"column:new_status;type:text"`
	CreatedAt       time.Time             `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"`
}

// TableName sets the GORM table name.
func (SubscriptionEvent) TableName() string {
	return "subscription_events"
}

View File

@@ -0,0 +1,17 @@
package billing_models
// SubscriptionEventType enumerates entries recorded in the subscription
// event history. These are internal, provider-agnostic names (distinct
// from raw provider webhook event types).
type SubscriptionEventType string

const (
	EventCreated                SubscriptionEventType = "subscription.created"
	EventUpgraded               SubscriptionEventType = "subscription.upgraded"
	EventDowngraded             SubscriptionEventType = "subscription.downgraded"
	EventNewBillingCycleStarted SubscriptionEventType = "subscription.new_billing_cycle_started"
	EventCanceled               SubscriptionEventType = "subscription.canceled"
	EventReactivated            SubscriptionEventType = "subscription.reactivated"
	EventExpired                SubscriptionEventType = "subscription.expired"
	EventPastDue                SubscriptionEventType = "subscription.past_due"
	EventRecoveredFromPastDue   SubscriptionEventType = "subscription.recovered_from_past_due"
	EventRefund                 SubscriptionEventType = "payment.refund"
	EventDispute                SubscriptionEventType = "payment.dispute"
)

View File

@@ -0,0 +1,11 @@
package billing_models
// SubscriptionStatus is the lifecycle state of a Subscription.
// See Subscription.CanCreateNewBackups / GetBackupsStorageGB for how each
// state affects backup behavior.
type SubscriptionStatus string

const (
	StatusTrial    SubscriptionStatus = "trial"     // trial period (~24h after DB creation)
	StatusActive   SubscriptionStatus = "active"    // paid, everything works
	StatusPastDue  SubscriptionStatus = "past_due"  // payment failed, trying to charge again, but everything still works
	StatusCanceled SubscriptionStatus = "canceled"  // subscription canceled by user or after past_due (grace period is active)
	StatusExpired  SubscriptionStatus = "expired"   // grace period ended, data marked for deletion, can come from canceled and trial
)

View File

@@ -0,0 +1,22 @@
package billing_models
import (
"time"
"github.com/google/uuid"
)
// WebhookEvent is the provider-agnostic, normalized form of a payment
// provider webhook, produced by provider adapters (e.g. Paddle) for the
// billing service to consume. Not persisted directly.
type WebhookEvent struct {
	// RequestID correlates log lines for one webhook delivery.
	RequestID       uuid.UUID
	ProviderEventID string
	// DatabaseID is extracted from provider custom data when present;
	// nil when the event must be matched by ProviderSubscriptionID.
	DatabaseID             *uuid.UUID
	Type                   WebhookEventType
	ProviderSubscriptionID string
	ProviderCustomerID     string
	ProviderInvoiceID      string
	QuantityGB             int
	Status                 SubscriptionStatus
	PeriodStart            *time.Time
	PeriodEnd              *time.Time
	AmountCents            int64
}

View File

@@ -0,0 +1,13 @@
package billing_models
// WebhookEventType enumerates normalized webhook event kinds. Provider
// adapters map raw provider names onto these (e.g. Paddle's
// "transaction.completed" → payment.succeeded, "adjustment.created" →
// dispute.created).
type WebhookEventType string

const (
	WHEventSubscriptionCreated        WebhookEventType = "subscription.created"
	WHEventSubscriptionUpdated        WebhookEventType = "subscription.updated"
	WHEventSubscriptionCanceled      WebhookEventType = "subscription.canceled"
	WHEventSubscriptionPastDue       WebhookEventType = "subscription.past_due"
	WHEventSubscriptionReactivated   WebhookEventType = "subscription.reactivated"
	WHEventPaymentSucceeded          WebhookEventType = "payment.succeeded"
	WHEventSubscriptionDisputeCreated WebhookEventType = "dispute.created"
)

View File

@@ -0,0 +1,5 @@
**Paddle hints:**
- **max_quantity on price:** Paddle limits `quantity` on a price to 100 by default. You need to explicitly set the range (`quantity: {minimum: 20, maximum: 10000}`) when creating a price via API or dashboard. Otherwise requests with quantity > 100 will return an error.
- **Full items list on update:** Unlike Stripe, Paddle requires sending **all** subscription items in `PATCH /subscriptions/{id}`, not just the changed ones. `proration_billing_mode` is also required. Without this you can accidentally remove a line item or get a 400.
- **Webhook events mapping:** Paddle uses `transaction.completed` instead of `payment.succeeded`, `transaction.payment_failed` instead of `payment.failed`, `adjustment.created` instead of `dispute.created`.

View File

@@ -0,0 +1,83 @@
package billing_paddle
import (
"encoding/json"
"errors"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
billing_webhooks "databasus-backend/internal/features/billing/webhooks"
"databasus-backend/internal/util/logger"
)
// PaddleBillingController handles unauthenticated, signature-verified
// callbacks from Paddle.
type PaddleBillingController struct {
	paddleBillingService *PaddleBillingService
}

// RegisterPublicRoutes mounts the webhook endpoint outside the auth
// middleware — Paddle authenticates via the Paddle-Signature header, not
// a user session.
func (c *PaddleBillingController) RegisterPublicRoutes(router *gin.RouterGroup) {
	router.POST("/billing/paddle/webhook", c.HandlePaddleWebhook)
}
// HandlePaddleWebhook
// @Summary Handle Paddle webhook
// @Description Process incoming webhook events from Paddle payment provider
// @Tags billing
// @Accept json
// @Produce json
// @Param Paddle-Signature header string true "Paddle webhook signature"
// @Success 200
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500
// @Router /billing/paddle/webhook [post]
//
// Flow: read body (capped) → verify signature over the raw bytes →
// decode the envelope → dispatch to the service. Duplicate events return
// 200 so Paddle stops retrying; other processing errors return 500 so
// Paddle retries the delivery.
func (c *PaddleBillingController) HandlePaddleWebhook(ctx *gin.Context) {
	requestID := uuid.New()
	log := logger.GetLogger().With("request_id", requestID)

	// Cap the body at 1 MiB to bound memory on hostile input. A larger
	// (truncated) payload would fail signature verification below.
	body, err := io.ReadAll(io.LimitReader(ctx.Request.Body, 1<<20))
	if err != nil {
		log.Error("failed to read webhook request body", "error", err)
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Failed to read request body"})
		return
	}

	// Flatten headers for the verifier.
	// NOTE(review): GetHeader returns only the first value per key —
	// confirm Paddle never sends multi-valued headers relevant to
	// signature verification.
	headers := make(map[string]string)
	for k := range ctx.Request.Header {
		headers[k] = ctx.GetHeader(k)
	}

	// Verify before parsing: unsigned payloads must never reach the decoder.
	if err := c.paddleBillingService.VerifyWebhookSignature(body, headers); err != nil {
		log.Warn("paddle webhook signature verification failed", "error", err)
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid webhook signature"})
		return
	}

	var webhookDTO PaddleWebhookDTO
	if err := json.Unmarshal(body, &webhookDTO); err != nil {
		log.Error("failed to unmarshal webhook payload", "error", err)
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid webhook payload"})
		return
	}

	log = log.With(
		"provider_event_id", webhookDTO.EventID,
		"event_type", webhookDTO.EventType,
	)

	if err := c.paddleBillingService.ProcessWebhookEvent(log, requestID, webhookDTO, body); err != nil {
		if errors.Is(err, billing_webhooks.ErrDuplicateWebhook) {
			log.Info("duplicate webhook event, returning 200 to not force retry")
			ctx.Status(http.StatusOK)
			return
		}
		log.Error("Failed to process paddle webhook", "error", err)
		ctx.Status(http.StatusInternalServerError)
		return
	}

	ctx.Status(http.StatusOK)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,72 @@
package billing_paddle
import (
"sync"
"github.com/PaddleHQ/paddle-go-sdk"
"databasus-backend/internal/config"
"databasus-backend/internal/features/billing"
billing_webhooks "databasus-backend/internal/features/billing/webhooks"
)
var (
paddleBillingService *PaddleBillingService
paddleBillingController *PaddleBillingController
initOnce sync.Once
)
// GetPaddleBillingService lazily constructs the Paddle-backed billing
// service singleton. Returns nil when the instance is not running in
// cloud mode, or when the Paddle client could not be constructed.
func GetPaddleBillingService() *PaddleBillingService {
	if !config.GetEnv().IsCloud {
		return nil
	}

	initOnce.Do(func() {
		env := config.GetEnv()

		// Sandbox and production differ only in the SDK constructor; build
		// the client first so the service is assembled in exactly one place
		// (the original duplicated the whole literal in both branches).
		var (
			paddleClient *paddle.SDK
			err          error
		)
		if env.IsPaddleSandbox {
			paddleClient, err = paddle.NewSandbox(env.PaddleApiKey)
		} else {
			paddleClient, err = paddle.New(env.PaddleApiKey)
		}
		if err != nil {
			// NOTE(review): the error is swallowed and initOnce is already
			// consumed, so the service stays nil for the process lifetime.
			// This preserves the original behavior; consider failing fast
			// at startup instead.
			return
		}

		paddleBillingService = &PaddleBillingService{
			paddleClient,
			paddle.NewWebhookVerifier(env.PaddleWebhookSecret),
			env.PaddlePriceID,
			billing_webhooks.WebhookRepository{},
			billing.GetBillingService(),
		}
		paddleBillingController = &PaddleBillingController{paddleBillingService}
	})

	return paddleBillingService
}
// GetPaddleBillingController returns the webhook controller, initializing
// the underlying service on first use. Returns nil outside cloud mode.
// NOTE(review): if Paddle client construction failed inside
// GetPaddleBillingService, this returns nil as well — callers must handle
// that.
func GetPaddleBillingController() *PaddleBillingController {
	if !config.GetEnv().IsCloud {
		return nil
	}
	// Ensure service + controller are initialized
	GetPaddleBillingService()
	return paddleBillingController
}

// SetupDependencies injects the Paddle provider into the billing service,
// breaking the billing ↔ paddle package cycle at wiring time.
func SetupDependencies() {
	billing.GetBillingService().SetBillingProvider(GetPaddleBillingService())
}

View File

@@ -0,0 +1,9 @@
package billing_paddle
import "encoding/json"
type PaddleWebhookDTO struct {
EventID string `json:"event_id"`
EventType string `json:"event_type"`
Data json.RawMessage
}

View File

@@ -0,0 +1,50 @@
package billing_paddle
import "time"
// Test payload builders' input structs: each describes one synthetic
// Paddle webhook used by the test suite to drive the webhook handler.

// TestSubscriptionCreatedPayload describes a subscription.created webhook.
type TestSubscriptionCreatedPayload struct {
	EventID     string
	SubID       string
	CustomerID  string
	DatabaseID  string
	QuantityGB  int
	PeriodStart time.Time
	PeriodEnd   time.Time
}

// TestSubscriptionUpdatedPayload describes a subscription.updated webhook,
// optionally carrying a scheduled change (e.g. a pending downgrade).
type TestSubscriptionUpdatedPayload struct {
	EventID               string
	SubID                 string
	CustomerID            string
	QuantityGB            int
	PeriodStart           time.Time
	PeriodEnd             time.Time
	HasScheduledChange    bool
	ScheduledChangeAction string
}

// TestSubscriptionCanceledPayload describes a subscription.canceled webhook.
type TestSubscriptionCanceledPayload struct {
	EventID    string
	SubID      string
	CustomerID string
}

// TestTransactionCompletedPayload describes a transaction.completed
// webhook (Paddle's successful-payment event).
type TestTransactionCompletedPayload struct {
	EventID     string
	TxnID       string
	SubID       string
	CustomerID  string
	TotalCents  int64
	QuantityGB  int
	PeriodStart time.Time
	PeriodEnd   time.Time
}

// TestSubscriptionPastDuePayload describes a past-due webhook.
type TestSubscriptionPastDuePayload struct {
	EventID     string
	SubID       string
	CustomerID  string
	QuantityGB  int
	PeriodStart time.Time
	PeriodEnd   time.Time
}

View File

@@ -0,0 +1,638 @@
package billing_paddle
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strconv"
"time"
"github.com/PaddleHQ/paddle-go-sdk"
"github.com/google/uuid"
"databasus-backend/internal/features/billing"
billing_models "databasus-backend/internal/features/billing/models"
billing_provider "databasus-backend/internal/features/billing/provider"
billing_webhooks "databasus-backend/internal/features/billing/webhooks"
)
// PaddleBillingService implements the billing provider interface on top of
// the Paddle Billing API.
type PaddleBillingService struct {
	client *paddle.SDK
	// webhookVerified checks Paddle-Signature headers against raw bodies.
	// NOTE(review): name looks like a typo for webhookVerifier — renaming
	// would touch uses elsewhere in this file.
	webhookVerified *paddle.WebhookVerifier
	// priceID is the single catalog price (per-GB) all subscriptions use.
	priceID           string
	webhookRepository billing_webhooks.WebhookRepository
	billingService    *billing.BillingService
}

// GetProviderName identifies this implementation to the billing service.
func (s *PaddleBillingService) GetProviderName() billing_provider.ProviderName {
	return billing_provider.ProviderPaddle
}
// CreateCheckoutSession creates a Paddle transaction for the requested
// storage and returns its id; the frontend opens Paddle checkout with it.
// The database id travels in custom data so webhooks can link the
// resulting subscription back to the database.
func (s *PaddleBillingService) CreateCheckoutSession(
	logger *slog.Logger,
	request billing_provider.CheckoutRequest,
) (string, error) {
	logger = logger.With("database_id", request.DatabaseID)
	logger.Debug(fmt.Sprintf("paddle: creating checkout session for %d GB", request.StorageGB))

	// Quantity encodes GB directly against the per-GB catalog price.
	// NOTE(review): Paddle caps quantity at 100 unless the price's
	// quantity range was raised when the price was created — verify the
	// configured price allows the 20–10000 range.
	txRequest := &paddle.CreateTransactionRequest{
		Items: []paddle.CreateTransactionItems{
			*paddle.NewCreateTransactionItemsCatalogItem(&paddle.CatalogItem{
				PriceID:  s.priceID,
				Quantity: request.StorageGB,
			}),
		},
		CustomData: paddle.CustomData{"database_id": request.DatabaseID.String()},
		Checkout: &paddle.TransactionCheckout{
			URL: &request.SuccessURL,
		},
	}

	tx, err := s.client.CreateTransaction(context.Background(), txRequest)
	if err != nil {
		logger.Error("paddle: failed to create transaction", "error", err)
		return "", err
	}

	return tx.ID, nil
}
// UpgradeQuantityWithSurcharge immediately raises the subscription's item
// quantity (GB); Paddle prorates and charges the difference right away.
// No-op when the subscription is already at the requested quantity.
func (s *PaddleBillingService) UpgradeQuantityWithSurcharge(
	logger *slog.Logger,
	providerSubscriptionID string,
	quantityGB int,
) error {
	logger = logger.With("provider_subscription_id", providerSubscriptionID)
	logger.Debug(fmt.Sprintf("paddle: applying upgrade: new storage %d GB", quantityGB))

	// important: paddle requires to send all items
	// in the subscription when updating, not just the changed one
	subscription, err := s.client.GetSubscription(context.Background(), &paddle.GetSubscriptionRequest{
		SubscriptionID: providerSubscriptionID,
	})
	if err != nil {
		logger.Error("paddle: failed to get subscription", "error", err)
		return err
	}

	// Robustness fix: guard against an empty items list instead of
	// panicking on Items[0].
	if len(subscription.Items) == 0 {
		err := fmt.Errorf("paddle: subscription %s has no items", providerSubscriptionID)
		logger.Error("paddle: subscription has no items")
		return err
	}

	currentQuantity := subscription.Items[0].Quantity
	if currentQuantity == quantityGB {
		logger.Info("paddle: subscription already at requested quantity, skipping upgrade",
			"current_quantity_gb", currentQuantity,
			"requested_quantity_gb", quantityGB,
		)
		return nil
	}

	priceID := subscription.Items[0].Price.ID
	_, err = s.client.UpdateSubscription(context.Background(), &paddle.UpdateSubscriptionRequest{
		SubscriptionID: providerSubscriptionID,
		Items: paddle.NewPatchField([]paddle.SubscriptionUpdateCatalogItem{
			{
				PriceID:  priceID,
				Quantity: quantityGB,
			},
		}),
		ProrationBillingMode: paddle.NewPatchField(paddle.ProrationBillingModeProratedImmediately),
	})
	if err != nil {
		logger.Error("paddle: failed to update subscription", "error", err)
		return err
	}

	logger.Debug("paddle: successfully applied upgrade")
	return nil
}
// ScheduleQuantityDowngradeFromNextBillingCycle schedules a lower item
// quantity (GB) to take effect at the next billing period. No-op when the
// subscription is already at the requested quantity or already has a
// scheduled change pending.
func (s *PaddleBillingService) ScheduleQuantityDowngradeFromNextBillingCycle(
	logger *slog.Logger,
	providerSubscriptionID string,
	quantityGB int,
) error {
	logger = logger.With("provider_subscription_id", providerSubscriptionID)
	logger.Debug(fmt.Sprintf("paddle: scheduling downgrade from next billing cycle: new storage %d GB", quantityGB))

	// important: paddle requires to send all items
	// in the subscription when updating, not just the changed one
	subscription, err := s.client.GetSubscription(context.Background(), &paddle.GetSubscriptionRequest{
		SubscriptionID: providerSubscriptionID,
	})
	if err != nil {
		logger.Error("paddle: failed to get subscription", "error", err)
		return err
	}

	// Robustness fix: guard against an empty items list instead of
	// panicking on Items[0].
	if len(subscription.Items) == 0 {
		err := fmt.Errorf("paddle: subscription %s has no items", providerSubscriptionID)
		logger.Error("paddle: subscription has no items")
		return err
	}

	currentQuantity := subscription.Items[0].Quantity
	if currentQuantity == quantityGB {
		logger.Info("paddle: subscription already at requested quantity, skipping downgrade",
			"current_quantity_gb", currentQuantity,
			"requested_quantity_gb", quantityGB,
		)
		return nil
	}

	if subscription.ScheduledChange != nil {
		logger.Info("paddle: subscription already has a scheduled change, skipping downgrade")
		return nil
	}

	priceID := subscription.Items[0].Price.ID

	// apply downgrade from next billing cycle by setting the proration billing mode to "prorate on next billing period"
	_, err = s.client.UpdateSubscription(context.Background(), &paddle.UpdateSubscriptionRequest{
		SubscriptionID: providerSubscriptionID,
		Items: paddle.NewPatchField([]paddle.SubscriptionUpdateCatalogItem{
			{
				PriceID:  priceID,
				Quantity: quantityGB,
			},
		}),
		ProrationBillingMode: paddle.NewPatchField(paddle.ProrationBillingModeFullNextBillingPeriod),
	})
	if err != nil {
		logger.Error("paddle: failed to update subscription for downgrade", "error", err)
		return fmt.Errorf("failed to update subscription: %w", err)
	}

	logger.Debug("paddle: successfully scheduled downgrade from next billing cycle")
	return nil
}
// GetSubscription fetches the subscription from Paddle and converts it into
// the provider-agnostic representation.
func (s *PaddleBillingService) GetSubscription(
	logger *slog.Logger,
	providerSubscriptionID string,
) (billing_provider.ProviderSubscription, error) {
	logger = logger.With("provider_subscription_id", providerSubscriptionID)
	logger.Debug("paddle: getting subscription details")
	subscription, err := s.client.GetSubscription(context.Background(), &paddle.GetSubscriptionRequest{
		SubscriptionID: providerSubscriptionID,
	})
	if err != nil {
		logger.Error("paddle: failed to get subscription", "error", err)
		return billing_provider.ProviderSubscription{}, err
	}
	// check emptiness BEFORE the debug log below: indexing Items[0] on an
	// empty slice would panic before toProviderSubscription could validate it
	if len(subscription.Items) == 0 {
		logger.Error("paddle: subscription has no items")
		return billing_provider.ProviderSubscription{}, fmt.Errorf(
			"paddle subscription %s has no items",
			providerSubscriptionID,
		)
	}
	logger.Debug(
		fmt.Sprintf(
			"paddle: successfully got subscription details: status=%s, quantity=%d",
			subscription.Status,
			subscription.Items[0].Quantity,
		),
	)
	return s.toProviderSubscription(logger, subscription)
}
// CreatePortalSession returns a customer-facing URL for managing billing.
// Paddle has no generic billing portal, so the "update payment method"
// management URL of the customer's first active/past-due subscription is used.
//
// NOTE(review): the returnURL parameter is accepted for interface
// compatibility but is unused in this implementation — confirm intent.
func (s *PaddleBillingService) CreatePortalSession(
	logger *slog.Logger,
	providerCustomerID, returnURL string,
) (string, error) {
	logger = logger.With("provider_customer_id", providerCustomerID)
	logger.Debug("paddle: creating portal session")
	// only active and past-due subscriptions are candidates for a portal URL
	subscriptions, err := s.client.ListSubscriptions(context.Background(), &paddle.ListSubscriptionsRequest{
		CustomerID: []string{providerCustomerID},
		Status: []string{
			string(paddle.SubscriptionStatusActive),
			string(paddle.SubscriptionStatusPastDue),
		},
	})
	if err != nil {
		logger.Error("paddle: failed to list subscriptions for portal session", "error", err)
		return "", err
	}
	// the SDK returns a paginated collection; Next advances to the first entry
	res := subscriptions.Next(context.Background())
	if !res.Ok() {
		if res.Err() != nil {
			logger.Error("paddle: failed to iterate subscriptions", "error", res.Err())
			return "", res.Err()
		}
		// iteration succeeded but yielded nothing: no matching subscription
		logger.Error("paddle: no active subscriptions found for customer")
		return "", fmt.Errorf("no active subscriptions found for customer %s", providerCustomerID)
	}
	subscription := res.Value()
	if subscription.ManagementURLs.UpdatePaymentMethod == nil {
		logger.Error("paddle: subscription has no management URL")
		return "", fmt.Errorf("subscription %s has no management URL", subscription.ID)
	}
	return *subscription.ManagementURLs.UpdatePaymentMethod, nil
}
// VerifyWebhookSignature checks the Paddle signature headers against the raw
// webhook body. It returns nil only when the signature is valid.
func (s *PaddleBillingService) VerifyWebhookSignature(body []byte, headers map[string]string) error {
	// the paddle SDK verifier operates on an *http.Request, so wrap the raw
	// body and headers into a synthetic request
	req, err := http.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader(body))
	if err != nil {
		return fmt.Errorf("failed to build request for signature verification: %w", err)
	}
	for k, v := range headers {
		req.Header.Set(k, v)
	}
	ok, err := s.webhookVerified.Verify(req)
	if err != nil {
		return fmt.Errorf("failed to verify webhook signature: %w", err)
	}
	// distinguish "invalid signature" from "verification failed": the original
	// wrapped a nil error with %w when ok was false
	if !ok {
		return fmt.Errorf("webhook signature is invalid")
	}
	return nil
}
// ProcessWebhookEvent is the entry point for verified Paddle webhooks: it
// normalizes the payload, deduplicates by provider event ID, persists an
// audit record, and dispatches the event to the billing service.
//
// Returns billing_webhooks.ErrDuplicateWebhook when the same provider event
// was already processed successfully.
func (s *PaddleBillingService) ProcessWebhookEvent(
	logger *slog.Logger,
	requestID uuid.UUID,
	webhookDTO PaddleWebhookDTO,
	rawBody []byte,
) error {
	webhookEvent, err := s.normalizeWebhookEvent(
		logger,
		requestID,
		webhookDTO.EventID,
		webhookDTO.EventType,
		webhookDTO.Data,
	)
	if err != nil {
		// unsupported event types are recorded as skipped, not failed
		if errors.Is(err, billing_webhooks.ErrUnsupportedEventType) {
			return s.skipWebhookEvent(logger, requestID, webhookDTO, rawBody)
		}
		logger.Error("paddle: failed to normalize webhook event", "error", err)
		return err
	}
	logArgs := []any{
		"provider_event_id", webhookEvent.ProviderEventID,
		"provider_subscription_id", webhookEvent.ProviderSubscriptionID,
		"provider_customer_id", webhookEvent.ProviderCustomerID,
	}
	if webhookEvent.DatabaseID != nil {
		logArgs = append(logArgs, "database_id", webhookEvent.DatabaseID)
	}
	logger = logger.With(logArgs...)
	// idempotency: providers retry webhooks, so skip events that were already
	// processed successfully
	existingRecord, err := s.webhookRepository.FindSuccessfulByProviderEventID(webhookEvent.ProviderEventID)
	if err == nil && existingRecord != nil {
		logger.Info("paddle: webhook already processed successfully, skipping",
			"existing_request_id", existingRecord.RequestID,
		)
		return billing_webhooks.ErrDuplicateWebhook
	}
	// persist the raw payload BEFORE processing so failures remain auditable
	webhookRecord := &billing_webhooks.WebhookRecord{
		RequestID:       requestID,
		ProviderName:    billing_provider.ProviderPaddle,
		EventType:       string(webhookEvent.Type),
		ProviderEventID: webhookEvent.ProviderEventID,
		RawPayload:      string(rawBody),
	}
	if err := s.webhookRepository.Insert(webhookRecord); err != nil {
		logger.Error("paddle: failed to save webhook record", "error", err)
		return err
	}
	if err := s.processWebhookEvent(logger, webhookEvent); err != nil {
		logger.Error("paddle: failed to process webhook event", "error", err)
		// record the failure on the stored webhook; the processing error is
		// still what gets returned to the caller
		if markErr := s.webhookRepository.MarkError(requestID.String(), err.Error()); markErr != nil {
			logger.Error("paddle: failed to mark webhook as errored", "error", markErr)
		}
		return err
	}
	// a failed MarkProcessed is logged but does not fail the webhook
	if markErr := s.webhookRepository.MarkProcessed(requestID.String()); markErr != nil {
		logger.Error("paddle: failed to mark webhook as processed", "error", markErr)
	}
	return nil
}
// skipWebhookEvent persists a webhook we deliberately do not process (an
// unsupported event type) and marks the stored record as skipped so the
// event remains auditable.
func (s *PaddleBillingService) skipWebhookEvent(
	logger *slog.Logger,
	requestID uuid.UUID,
	webhookDTO PaddleWebhookDTO,
	rawBody []byte,
) error {
	record := billing_webhooks.WebhookRecord{
		RequestID:       requestID,
		ProviderName:    billing_provider.ProviderPaddle,
		EventType:       webhookDTO.EventType,
		ProviderEventID: webhookDTO.EventID,
		RawPayload:      string(rawBody),
	}
	if insertErr := s.webhookRepository.Insert(&record); insertErr != nil {
		logger.Error("paddle: failed to save skipped webhook record", "error", insertErr)
		return insertErr
	}
	// a failed MarkSkipped is logged but does not fail the request
	if markErr := s.webhookRepository.MarkSkipped(requestID.String()); markErr != nil {
		logger.Error("paddle: failed to mark webhook as skipped", "error", markErr)
	}
	return nil
}
// processWebhookEvent routes a normalized webhook event to the matching
// billing service handler.
//
// Ordering matters: subscription.created and dispute events are handled
// before the subscription lookup because no local subscription row can be
// resolved by provider subscription ID for them.
func (s *PaddleBillingService) processWebhookEvent(
	logger *slog.Logger,
	webhookEvent billing_models.WebhookEvent,
) error {
	logger.Debug("processing webhook event")
	// subscription.created - there is no subscription in the database yet
	if webhookEvent.Type == billing_models.WHEventSubscriptionCreated {
		return s.billingService.ActivateSubscription(logger, webhookEvent)
	}
	// dispute - finds subscription via invoice, no provider subscription ID available
	if webhookEvent.Type == billing_models.WHEventSubscriptionDisputeCreated {
		return s.billingService.RecordDispute(logger, webhookEvent)
	}
	// for others - search subscription first
	subscription, err := s.billingService.GetSubscriptionByProviderSubID(logger, webhookEvent.ProviderSubscriptionID)
	if err != nil {
		logger.Error("paddle: failed to find subscription for webhook event", "error", err)
		return err
	}
	logger = logger.With("subscription_id", subscription.ID, "database_id", subscription.DatabaseID)
	logger.Debug(fmt.Sprintf("found subscription in DB with ID: %s", subscription.ID))
	switch webhookEvent.Type {
	case billing_models.WHEventSubscriptionUpdated:
		// an update arriving for a locally-canceled subscription is treated
		// as the customer re-subscribing
		if subscription.Status == billing_models.StatusCanceled {
			return s.billingService.ReactivateSubscription(logger, subscription, webhookEvent)
		}
		return s.billingService.SyncSubscriptionFromProvider(logger, subscription, webhookEvent)
	case billing_models.WHEventSubscriptionCanceled:
		return s.billingService.CancelSubscription(logger, subscription, webhookEvent)
	case billing_models.WHEventPaymentSucceeded:
		return s.billingService.RecordPaymentSuccess(logger, subscription, webhookEvent)
	case billing_models.WHEventSubscriptionPastDue:
		return s.billingService.RecordPaymentFailed(logger, subscription, webhookEvent)
	default:
		// normalizeWebhookEvent should never emit a type that is not handled
		// above; log loudly but do not fail the webhook
		logger.Error(fmt.Sprintf("unhandled webhook event type: %s", string(webhookEvent.Type)))
		return nil
	}
}
// normalizeWebhookEvent converts a raw Paddle webhook payload into the
// provider-agnostic billing_models.WebhookEvent consumed by the billing
// service.
//
// Returns billing_webhooks.ErrUnsupportedEventType for event types this
// integration deliberately ignores so the caller can record and skip them.
// Malformed payloads return an error instead of panicking (the original used
// uuid.MustParse, which panicked the handler on a bad/empty database_id).
func (s *PaddleBillingService) normalizeWebhookEvent(
	logger *slog.Logger,
	requestID uuid.UUID,
	eventID, eventType string,
	data json.RawMessage,
) (billing_models.WebhookEvent, error) {
	webhookEvent := billing_models.WebhookEvent{
		RequestID:       requestID,
		ProviderEventID: eventID,
	}
	switch eventType {
	case "subscription.created":
		webhookEvent.Type = billing_models.WHEventSubscriptionCreated
		var subscription paddle.Subscription
		if err := json.Unmarshal(data, &subscription); err != nil {
			logger.Error("paddle: failed to unmarshal subscription.created webhook data", "error", err)
			return webhookEvent, err
		}
		webhookEvent.ProviderSubscriptionID = subscription.ID
		webhookEvent.ProviderCustomerID = subscription.CustomerID
		// guard Items access: an empty item list must not panic the handler
		if len(subscription.Items) > 0 {
			webhookEvent.QuantityGB = subscription.Items[0].Quantity
		}
		status, err := mapPaddleStatus(logger, subscription.Status)
		if err != nil {
			return webhookEvent, err
		}
		webhookEvent.Status = status
		if subscription.CurrentBillingPeriod != nil {
			periodStart := mustParseRFC3339(logger, "period start", subscription.CurrentBillingPeriod.StartsAt)
			periodEnd := mustParseRFC3339(logger, "period end", subscription.CurrentBillingPeriod.EndsAt)
			webhookEvent.PeriodStart = &periodStart
			webhookEvent.PeriodEnd = &periodEnd
		}
		// database_id is attached as checkout custom data and links the Paddle
		// subscription back to our database record; it is mandatory here
		if subscription.CustomData == nil || subscription.CustomData["database_id"] == "" {
			logger.Error("paddle: subscription has no database_id in custom data")
			return webhookEvent, fmt.Errorf("subscription has no database_id in custom data")
		}
		databaseIDStr, isOk := subscription.CustomData["database_id"].(string)
		if !isOk {
			logger.Error("paddle: database_id in custom data is not a string")
			return webhookEvent, fmt.Errorf("invalid database_id type in custom data")
		}
		// uuid.Parse instead of uuid.MustParse: a malformed ID from an
		// external payload must produce an error, not a panic
		databaseID, err := uuid.Parse(databaseIDStr)
		if err != nil {
			logger.Error("paddle: database_id in custom data is not a valid UUID", "error", err)
			return webhookEvent, fmt.Errorf("invalid database_id in custom data: %w", err)
		}
		webhookEvent.DatabaseID = &databaseID
	case "subscription.updated":
		var subscription paddle.Subscription
		if err := json.Unmarshal(data, &subscription); err != nil {
			return webhookEvent, err
		}
		webhookEvent.ProviderSubscriptionID = subscription.ID
		webhookEvent.ProviderCustomerID = subscription.CustomerID
		if len(subscription.Items) > 0 {
			webhookEvent.QuantityGB = subscription.Items[0].Quantity
		}
		webhookEvent.Type = billing_models.WHEventSubscriptionUpdated
		status, err := mapPaddleStatus(logger, subscription.Status)
		if err != nil {
			return webhookEvent, err
		}
		webhookEvent.Status = status
		if subscription.CurrentBillingPeriod != nil {
			periodStart := mustParseRFC3339(logger, "period start", subscription.CurrentBillingPeriod.StartsAt)
			periodEnd := mustParseRFC3339(logger, "period end", subscription.CurrentBillingPeriod.EndsAt)
			webhookEvent.PeriodStart = &periodStart
			webhookEvent.PeriodEnd = &periodEnd
		}
		// a scheduled cancel on an "updated" event is treated as cancellation
		if subscription.ScheduledChange != nil &&
			subscription.ScheduledChange.Action == paddle.ScheduledChangeActionCancel {
			webhookEvent.Type = billing_models.WHEventSubscriptionCanceled
		}
	case "subscription.canceled":
		webhookEvent.Type = billing_models.WHEventSubscriptionCanceled
		var subscription paddle.Subscription
		if err := json.Unmarshal(data, &subscription); err != nil {
			return webhookEvent, err
		}
		webhookEvent.ProviderSubscriptionID = subscription.ID
		webhookEvent.ProviderCustomerID = subscription.CustomerID
		status, err := mapPaddleStatus(logger, subscription.Status)
		if err != nil {
			return webhookEvent, err
		}
		webhookEvent.Status = status
	case "transaction.completed":
		webhookEvent.Type = billing_models.WHEventPaymentSucceeded
		var transaction paddle.Transaction
		if err := json.Unmarshal(data, &transaction); err != nil {
			return webhookEvent, err
		}
		webhookEvent.ProviderInvoiceID = transaction.ID
		if len(transaction.Items) > 0 {
			webhookEvent.QuantityGB = transaction.Items[0].Quantity
		}
		if transaction.SubscriptionID != nil {
			webhookEvent.ProviderSubscriptionID = *transaction.SubscriptionID
		}
		if transaction.CustomerID != nil {
			webhookEvent.ProviderCustomerID = *transaction.CustomerID
		}
		// paddle totals are strings of the smallest currency unit (cents)
		amountCents, err := strconv.ParseInt(transaction.Details.Totals.Total, 10, 64)
		if err != nil {
			// a bad total is logged but does not fail the event
			logger.Error("paddle: failed to parse transaction total", "error", err)
		} else {
			webhookEvent.AmountCents = amountCents
		}
		if transaction.BillingPeriod != nil {
			periodStart := mustParseRFC3339(logger, "period start", transaction.BillingPeriod.StartsAt)
			periodEnd := mustParseRFC3339(logger, "period end", transaction.BillingPeriod.EndsAt)
			webhookEvent.PeriodStart = &periodStart
			webhookEvent.PeriodEnd = &periodEnd
		}
	case "subscription.past_due":
		webhookEvent.Type = billing_models.WHEventSubscriptionPastDue
		var subscription paddle.Subscription
		if err := json.Unmarshal(data, &subscription); err != nil {
			return webhookEvent, err
		}
		webhookEvent.ProviderSubscriptionID = subscription.ID
		webhookEvent.ProviderCustomerID = subscription.CustomerID
		if len(subscription.Items) > 0 {
			webhookEvent.QuantityGB = subscription.Items[0].Quantity
		}
		status, err := mapPaddleStatus(logger, subscription.Status)
		if err != nil {
			return webhookEvent, err
		}
		webhookEvent.Status = status
		if subscription.CurrentBillingPeriod != nil {
			periodStart := mustParseRFC3339(logger, "period start", subscription.CurrentBillingPeriod.StartsAt)
			periodEnd := mustParseRFC3339(logger, "period end", subscription.CurrentBillingPeriod.EndsAt)
			webhookEvent.PeriodStart = &periodStart
			webhookEvent.PeriodEnd = &periodEnd
		}
	case "adjustment.created":
		// only the transaction reference is needed; the billing service
		// resolves the subscription via the invoice
		webhookEvent.Type = billing_models.WHEventSubscriptionDisputeCreated
		var adjustment struct {
			TransactionID string `json:"transaction_id"`
		}
		if err := json.Unmarshal(data, &adjustment); err != nil {
			return webhookEvent, err
		}
		webhookEvent.ProviderInvoiceID = adjustment.TransactionID
	default:
		logger.Debug("unsupported paddle event type, skipping", "event_type", eventType)
		return webhookEvent, billing_webhooks.ErrUnsupportedEventType
	}
	return webhookEvent, nil
}
// toProviderSubscription maps a Paddle SDK subscription into the
// provider-agnostic billing_provider.ProviderSubscription.
func (s *PaddleBillingService) toProviderSubscription(
	logger *slog.Logger,
	paddleSubscription *paddle.Subscription,
) (billing_provider.ProviderSubscription, error) {
	status, err := mapPaddleStatus(logger, paddleSubscription.Status)
	if err != nil {
		return billing_provider.ProviderSubscription{}, err
	}
	// at least one item is required to read the storage quantity
	if len(paddleSubscription.Items) == 0 {
		return billing_provider.ProviderSubscription{}, fmt.Errorf(
			"paddle subscription %s has no items",
			paddleSubscription.ID,
		)
	}
	result := billing_provider.ProviderSubscription{
		ProviderSubscriptionID: paddleSubscription.ID,
		ProviderCustomerID:     paddleSubscription.CustomerID,
		Status:                 status,
		QuantityGB:             paddleSubscription.Items[0].Quantity,
	}
	if period := paddleSubscription.CurrentBillingPeriod; period != nil {
		result.PeriodStart = mustParseRFC3339(logger, "period start", period.StartsAt)
		result.PeriodEnd = mustParseRFC3339(logger, "period end", period.EndsAt)
	}
	return result, nil
}
func mustParseRFC3339(logger *slog.Logger, label, value string) time.Time {
parsed, err := time.Parse(time.RFC3339, value)
if err != nil {
logger.Error(fmt.Sprintf("paddle: failed to parse %s", label), "error", err)
}
return parsed
}
// mapPaddleStatus translates a Paddle subscription status into the internal
// billing status. Paused subscriptions are reported as canceled. Unknown
// statuses are logged and returned as an error.
func mapPaddleStatus(logger *slog.Logger, s paddle.SubscriptionStatus) (billing_models.SubscriptionStatus, error) {
	switch s {
	case paddle.SubscriptionStatusActive:
		return billing_models.StatusActive, nil
	case paddle.SubscriptionStatusTrialing:
		return billing_models.StatusTrial, nil
	case paddle.SubscriptionStatusPastDue:
		return billing_models.StatusPastDue, nil
	case paddle.SubscriptionStatusCanceled, paddle.SubscriptionStatusPaused:
		// paused has no internal equivalent; treat it like canceled
		return billing_models.StatusCanceled, nil
	}
	logger.Error(fmt.Sprintf("paddle: unknown subscription status: %s", string(s)))
	return "", fmt.Errorf("paddle: unknown subscription status: %s", string(s))
}

View File

@@ -0,0 +1,38 @@
package billing_provider

import (
	"time"

	"github.com/google/uuid"

	billing_models "databasus-backend/internal/features/billing/models"
)

// CreateSubscriptionRequest describes a new provider subscription to create
// for a database.
type CreateSubscriptionRequest struct {
	ProviderCustomerID string
	DatabaseID         uuid.UUID
	StorageGB          int
}

// ProviderSubscription is the provider-agnostic view of a subscription as
// reported by the billing provider.
type ProviderSubscription struct {
	ProviderSubscriptionID string
	ProviderCustomerID     string
	Status                 billing_models.SubscriptionStatus
	QuantityGB             int
	PeriodStart            time.Time
	PeriodEnd              time.Time
}

// CheckoutRequest carries everything needed to start a provider checkout
// session for a database.
type CheckoutRequest struct {
	DatabaseID uuid.UUID
	Email      string
	StorageGB  int
	SuccessURL string
	CancelURL  string
}

// ProviderName identifies a billing provider implementation.
type ProviderName string

const (
	// ProviderPaddle is the Paddle billing provider.
	ProviderPaddle ProviderName = "paddle"
)

View File

@@ -0,0 +1,21 @@
package billing_provider

import "log/slog"

// BillingProvider abstracts a payment provider behind a provider-agnostic
// interface used by the billing service.
type BillingProvider interface {
	// GetProviderName returns the stable identifier of the provider.
	GetProviderName() ProviderName
	// UpgradeQuantityWithSurcharge raises the storage quantity immediately.
	UpgradeQuantityWithSurcharge(logger *slog.Logger, providerSubscriptionID string, quantityGB int) error
	// ScheduleQuantityDowngradeFromNextBillingCycle lowers the storage
	// quantity starting with the next billing period.
	ScheduleQuantityDowngradeFromNextBillingCycle(
		logger *slog.Logger,
		providerSubscriptionID string,
		quantityGB int,
	) error
	// GetSubscription fetches the provider's current view of a subscription.
	GetSubscription(logger *slog.Logger, providerSubscriptionID string) (ProviderSubscription, error)
	// CreateCheckoutSession starts a checkout flow and returns its URL.
	CreateCheckoutSession(logger *slog.Logger, req CheckoutRequest) (checkoutURL string, err error)
	// CreatePortalSession returns a URL where the customer manages billing.
	CreatePortalSession(logger *slog.Logger, providerCustomerID, returnURL string) (portalURL string, err error)
}

View File

@@ -0,0 +1,72 @@
package billing_repositories

import (
	"errors"

	"github.com/google/uuid"
	"gorm.io/gorm"

	billing_models "databasus-backend/internal/features/billing/models"
	"databasus-backend/internal/storage"
)

// InvoiceRepository persists billing invoices.
type InvoiceRepository struct{}

// Save inserts the invoice when it has no ID yet, otherwise updates the
// existing row. The subscription ID is mandatory.
func (r *InvoiceRepository) Save(invoice billing_models.Invoice) error {
	if invoice.SubscriptionID == uuid.Nil {
		return errors.New("subscription id is required")
	}
	db := storage.GetDb()
	if invoice.ID == uuid.Nil {
		invoice.ID = uuid.New()
		return db.Create(&invoice).Error
	}
	// pass a pointer so GORM gets an addressable struct on update
	// (consistent with Create above and with SubscriptionRepository.Save)
	return db.Save(&invoice).Error
}

// FindByProviderInvID returns the invoice with the given provider invoice ID,
// or nil (without an error) when no such invoice exists.
func (r *InvoiceRepository) FindByProviderInvID(providerInvoiceID string) (*billing_models.Invoice, error) {
	var invoice billing_models.Invoice
	if err := storage.GetDb().Where("provider_invoice_id = ?", providerInvoiceID).
		First(&invoice).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}
	return &invoice, nil
}

// FindByDatabaseID returns a page of invoices that belong to the database's
// subscriptions, newest first.
func (r *InvoiceRepository) FindByDatabaseID(
	databaseID uuid.UUID,
	limit, offset int,
) ([]*billing_models.Invoice, error) {
	var invoices []*billing_models.Invoice
	if err := storage.GetDb().Joins("JOIN subscriptions ON subscriptions.id = invoices.subscription_id").
		Where("subscriptions.database_id = ?", databaseID).
		Order("invoices.created_at DESC").
		Limit(limit).
		Offset(offset).
		Find(&invoices).Error; err != nil {
		return nil, err
	}
	return invoices, nil
}

// CountByDatabaseID returns the total number of invoices across all
// subscriptions of the database (for pagination).
func (r *InvoiceRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error) {
	var count int64
	err := storage.GetDb().Model(&billing_models.Invoice{}).
		Joins("JOIN subscriptions ON subscriptions.id = invoices.subscription_id").
		Where("subscriptions.database_id = ?", databaseID).
		Count(&count).Error
	return count, err
}

View File

@@ -0,0 +1,50 @@
package billing_repositories

import (
	"errors"

	"github.com/google/uuid"

	billing_models "databasus-backend/internal/features/billing/models"
	"databasus-backend/internal/storage"
)

// SubscriptionEventRepository persists subscription lifecycle events.
type SubscriptionEventRepository struct{}

// Create inserts a new event with a freshly generated ID. The subscription
// ID is mandatory.
func (r *SubscriptionEventRepository) Create(event billing_models.SubscriptionEvent) error {
	if event.SubscriptionID == uuid.Nil {
		return errors.New("subscription id is required")
	}
	event.ID = uuid.New()
	return storage.GetDb().Create(&event).Error
}

// FindByDatabaseID returns a page of events across all subscriptions of the
// database, newest first.
func (r *SubscriptionEventRepository) FindByDatabaseID(
	databaseID uuid.UUID,
	limit, offset int,
) ([]*billing_models.SubscriptionEvent, error) {
	var events []*billing_models.SubscriptionEvent
	if err := storage.GetDb().Joins("JOIN subscriptions ON subscriptions.id = subscription_events.subscription_id").
		Where("subscriptions.database_id = ?", databaseID).
		Order("subscription_events.created_at DESC").
		Limit(limit).
		Offset(offset).
		Find(&events).Error; err != nil {
		return nil, err
	}
	return events, nil
}

// CountByDatabaseID returns the total number of events across all
// subscriptions of the database (for pagination).
func (r *SubscriptionEventRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error) {
	var count int64
	err := storage.GetDb().Model(&billing_models.SubscriptionEvent{}).
		Joins("JOIN subscriptions ON subscriptions.id = subscription_events.subscription_id").
		Where("subscriptions.database_id = ?", databaseID).
		Count(&count).Error
	return count, err
}

View File

@@ -0,0 +1,123 @@
package billing_repositories

import (
	"errors"
	"time"

	"github.com/google/uuid"
	"gorm.io/gorm"

	billing_models "databasus-backend/internal/features/billing/models"
	"databasus-backend/internal/storage"
)

// SubscriptionRepository persists billing subscriptions.
type SubscriptionRepository struct{}

// Save inserts the subscription when it has no ID yet, otherwise updates the
// existing row.
func (r *SubscriptionRepository) Save(sub billing_models.Subscription) error {
	db := storage.GetDb()
	if sub.ID == uuid.Nil {
		sub.ID = uuid.New()
		return db.Create(&sub).Error
	}
	return db.Save(&sub).Error
}

// FindByID returns the subscription with the given ID, or nil (without an
// error) when it does not exist.
func (r *SubscriptionRepository) FindByID(id uuid.UUID) (*billing_models.Subscription, error) {
	var sub billing_models.Subscription
	if err := storage.GetDb().Where("id = ?", id).First(&sub).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}
	return &sub, nil
}

// FindByDatabaseIDAndStatuses returns all subscriptions of the database that
// are in one of the given statuses.
func (r *SubscriptionRepository) FindByDatabaseIDAndStatuses(
	databaseID uuid.UUID,
	statuses []billing_models.SubscriptionStatus, // fixed typo: was "stauses"
) ([]*billing_models.Subscription, error) {
	var subs []*billing_models.Subscription
	if err := storage.GetDb().Where("database_id = ? AND status IN ?", databaseID, statuses).
		Find(&subs).Error; err != nil {
		return nil, err
	}
	return subs, nil
}

// FindLatestByDatabaseID returns the most recently created subscription for
// the database, or nil (without an error) when none exists.
func (r *SubscriptionRepository) FindLatestByDatabaseID(databaseID uuid.UUID) (*billing_models.Subscription, error) {
	var sub billing_models.Subscription
	if err := storage.GetDb().
		Where("database_id = ?", databaseID).
		Order("created_at DESC").
		First(&sub).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}
	return &sub, nil
}

// FindByProviderSubID returns the subscription with the given provider
// subscription ID, or nil (without an error) when none exists.
func (r *SubscriptionRepository) FindByProviderSubID(providerSubID string) (*billing_models.Subscription, error) {
	var sub billing_models.Subscription
	if err := storage.GetDb().Where("provider_sub_id = ?", providerSubID).
		First(&sub).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}
	return &sub, nil
}

// FindByStatuses returns all subscriptions in any of the given statuses.
func (r *SubscriptionRepository) FindByStatuses(
	statuses []billing_models.SubscriptionStatus,
) ([]billing_models.Subscription, error) {
	var subs []billing_models.Subscription
	if err := storage.GetDb().Where("status IN ?", statuses).Find(&subs).Error; err != nil {
		return nil, err
	}
	return subs, nil
}

// FindCanceledWithEndedGracePeriod returns canceled subscriptions whose data
// retention grace period expired before the given moment.
func (r *SubscriptionRepository) FindCanceledWithEndedGracePeriod(
	now time.Time,
) ([]billing_models.Subscription, error) {
	var subs []billing_models.Subscription
	if err := storage.GetDb().
		Where("status = ? AND data_retention_grace_period_until < ?", billing_models.StatusCanceled, now).
		Find(&subs).
		Error; err != nil {
		return nil, err
	}
	return subs, nil
}

// FindExpiredTrials returns trial subscriptions whose current billing period
// ended before the given moment.
func (r *SubscriptionRepository) FindExpiredTrials(now time.Time) ([]billing_models.Subscription, error) {
	var subs []billing_models.Subscription
	if err := storage.GetDb().Where("status = ? AND current_period_end < ?", billing_models.StatusTrial, now).
		Find(&subs).Error; err != nil {
		return nil, err
	}
	return subs, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,8 @@
package billing_webhooks

import "errors"

// Sentinel errors surfaced by webhook processing.
var (
	// ErrDuplicateWebhook indicates the provider event was already processed.
	ErrDuplicateWebhook = errors.New("duplicate webhook event")
	// ErrUnsupportedEventType indicates an event type this integration ignores.
	ErrUnsupportedEventType = errors.New("unsupported webhook event type")
)

View File

@@ -0,0 +1,25 @@
package billing_webhooks

import (
	"time"

	"github.com/google/uuid"

	billing_provider "databasus-backend/internal/features/billing/provider"
)

// WebhookRecord is the persisted audit/idempotency record of an incoming
// provider webhook.
type WebhookRecord struct {
	RequestID       uuid.UUID                     `gorm:"column:request_id;primaryKey;type:uuid;default:gen_random_uuid()"`
	ProviderName    billing_provider.ProviderName `gorm:"column:provider_name;type:text;not null"`
	EventType       string                        `gorm:"column:event_type;type:text;not null"`
	ProviderEventID string                        `gorm:"column:provider_event_id;type:text;not null;index"` // indexed: used for deduplication lookups
	RawPayload      string                        `gorm:"column:raw_payload;type:text;not null"`             // raw body kept for auditing/replay
	ProcessedAt     *time.Time                    `gorm:"column:processed_at"`                               // nil while unprocessed or failed
	IsSkipped       bool                          `gorm:"column:is_skipped;not null;default:false"`          // true for deliberately ignored event types
	Error           *string                       `gorm:"column:error"`                                      // last processing error message, if any
	CreatedAt       time.Time                     `gorm:"column:created_at;not null"`
}

// TableName maps the model to the webhook_records table.
func (WebhookRecord) TableName() string {
	return "webhook_records"
}

View File

@@ -0,0 +1,73 @@
package billing_webhooks

import (
	"errors"
	"time"

	"gorm.io/gorm"

	"databasus-backend/internal/storage"
)

// WebhookRepository persists incoming provider webhook records for
// idempotency and auditing.
type WebhookRepository struct{}

// FindSuccessfulByProviderEventID returns the already-processed webhook
// record with the given provider event ID, or nil (without an error) when no
// such record exists. Used for webhook deduplication.
func (r *WebhookRepository) FindSuccessfulByProviderEventID(providerEventID string) (*WebhookRecord, error) {
	var record WebhookRecord
	err := storage.GetDb().
		Where("provider_event_id = ? AND processed_at IS NOT NULL", providerEventID).
		First(&record).Error
	if err != nil {
		// "not found" is a normal outcome, not an error (the original
		// returned the error here, making the branch dead code)
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}
	return &record, nil
}

// Insert stores a new webhook record; CreatedAt is stamped in UTC.
func (r *WebhookRepository) Insert(record *WebhookRecord) error {
	if record.ProviderEventID == "" {
		return errors.New("provider event ID is required")
	}
	record.CreatedAt = time.Now().UTC()
	return storage.GetDb().Create(record).Error
}

// MarkProcessed stamps the record as successfully processed.
func (r *WebhookRepository) MarkProcessed(requestID string) error {
	now := time.Now().UTC()
	return storage.
		GetDb().
		Model(&WebhookRecord{}).
		Where("request_id = ?", requestID).
		Update("processed_at", now).
		Error
}

// MarkSkipped stamps the record as processed but deliberately skipped
// (e.g. an unsupported event type).
func (r *WebhookRepository) MarkSkipped(requestID string) error {
	now := time.Now().UTC()
	return storage.
		GetDb().
		Model(&WebhookRecord{}).
		Where("request_id = ?", requestID).
		Updates(map[string]any{
			"is_skipped":   true,
			"processed_at": now,
		}).
		Error
}

// MarkError records a processing failure message; processed_at stays NULL so
// the event is not treated as successfully handled.
func (r *WebhookRepository) MarkError(requestID, errMsg string) error {
	return storage.
		GetDb().
		Model(&WebhookRecord{}).
		Where("request_id = ?", requestID).
		Update("error", errMsg).
		Error
}

View File

@@ -1328,6 +1328,143 @@ func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
}
}
// Test_CreateDatabase_WhenCloudAndUserIsNotReadOnly_ReturnsBadRequest verifies
// that in cloud mode a database cannot be registered with a read-write
// database user.
func Test_CreateDatabase_WhenCloudAndUserIsNotReadOnly_ReturnsBadRequest(t *testing.T) {
	enableCloud(t)
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Cloud Not ReadOnly", owner, router)
	defer workspaces_testing.RemoveTestWorkspace(workspace, router)
	// the default test postgres config uses a privileged (non read-only) user
	request := Database{
		Name:        "Cloud Non-ReadOnly DB",
		WorkspaceID: &workspace.ID,
		Type:        DatabaseTypePostgres,
		Postgresql:  getTestPostgresConfig(),
	}
	resp := test_utils.MakePostRequest(
		t,
		router,
		"/api/v1/databases/create",
		"Bearer "+owner.Token,
		request,
		http.StatusBadRequest,
	)
	assert.Contains(t, string(resp.Body), "in cloud mode, only read-only database users are allowed")
}
// Test_CreateDatabase_WhenCloudAndUserIsReadOnly_DatabaseCreated verifies that
// a database registered with read-only credentials is accepted in cloud mode.
//
// Ordering matters: the read-only DB user is provisioned while still in
// self-hosted mode (cloud mode forbids the privileged temp database), and
// cloud mode is only enabled afterwards for the actual assertion.
func Test_CreateDatabase_WhenCloudAndUserIsReadOnly_DatabaseCreated(t *testing.T) {
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Cloud ReadOnly", owner, router)
	defer workspaces_testing.RemoveTestWorkspace(workspace, router)
	// use a temporary database registration to mint read-only credentials
	database := createTestDatabaseViaAPI("Temp DB for RO User", workspace.ID, owner.Token, router)
	readOnlyUser := createReadOnlyUserViaAPI(t, router, database.ID, owner.Token)
	assert.NotEmpty(t, readOnlyUser.Username)
	assert.NotEmpty(t, readOnlyUser.Password)
	RemoveTestDatabase(database)
	// switch to cloud mode only now, for the behavior under test
	enableCloud(t)
	pgConfig := getTestPostgresConfig()
	pgConfig.Username = readOnlyUser.Username
	pgConfig.Password = readOnlyUser.Password
	request := Database{
		Name:        "Cloud ReadOnly DB",
		WorkspaceID: &workspace.ID,
		Type:        DatabaseTypePostgres,
		Postgresql:  pgConfig,
	}
	var response Database
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/databases/create",
		"Bearer "+owner.Token,
		request,
		http.StatusCreated,
		&response,
	)
	defer RemoveTestDatabase(&response)
	assert.Equal(t, "Cloud ReadOnly DB", response.Name)
	assert.NotEqual(t, uuid.Nil, response.ID)
}
// Test_CreateDatabase_WhenNotCloudAndUserIsNotReadOnly_DatabaseCreated
// verifies that in self-hosted mode a privileged database user is accepted.
func Test_CreateDatabase_WhenNotCloudAndUserIsNotReadOnly_DatabaseCreated(t *testing.T) {
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Non-Cloud", owner, router)
	defer workspaces_testing.RemoveTestWorkspace(workspace, router)
	request := Database{
		Name:        "Non-Cloud DB",
		WorkspaceID: &workspace.ID,
		Type:        DatabaseTypePostgres,
		Postgresql:  getTestPostgresConfig(),
	}
	var response Database
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/databases/create",
		"Bearer "+owner.Token,
		request,
		http.StatusCreated,
		&response,
	)
	defer RemoveTestDatabase(&response)
	assert.Equal(t, "Non-Cloud DB", response.Name)
	assert.NotEqual(t, uuid.Nil, response.ID)
}
// enableCloud switches the process-wide config into cloud mode for the
// duration of the test and restores self-hosted mode on cleanup.
// NOTE(review): this mutates shared global config, so tests using it must
// not run in parallel with tests that depend on self-hosted mode.
func enableCloud(t *testing.T) {
	t.Helper()
	config.GetEnv().IsCloud = true
	t.Cleanup(func() {
		config.GetEnv().IsCloud = false
	})
}
// createReadOnlyUserViaAPI fetches the full database record and asks the API
// to provision read-only credentials for it, returning the generated
// username/password pair.
func createReadOnlyUserViaAPI(
	t *testing.T,
	router *gin.Engine,
	databaseID uuid.UUID,
	token string,
) *CreateReadOnlyUserResponse {
	// the create-readonly-user endpoint takes the full database payload,
	// so fetch it first
	var database Database
	test_utils.MakeGetRequestAndUnmarshal(
		t,
		router,
		fmt.Sprintf("/api/v1/databases/%s", databaseID.String()),
		"Bearer "+token,
		http.StatusOK,
		&database,
	)
	var response CreateReadOnlyUserResponse
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/databases/create-readonly-user",
		"Bearer "+token,
		database,
		http.StatusOK,
		&response,
	)
	return &response
}
func getTestMariadbConfig() *mariadb.MariadbDatabase {
env := config.GetEnv()
portStr := env.TestMariadb1011Port

View File

@@ -81,8 +81,8 @@ func (p *PostgresqlDatabase) Validate() error {
p.BackupType = PostgresBackupTypePgDump
}
if p.BackupType == PostgresBackupTypePgDump && config.GetEnv().IsCloud {
return errors.New("PG_DUMP backup type is not supported in cloud mode")
if p.BackupType != PostgresBackupTypePgDump && config.GetEnv().IsCloud {
return errors.New("only PG_DUMP backup type is supported in cloud mode")
}
if p.BackupType == PostgresBackupTypePgDump {

View File

@@ -1310,6 +1310,46 @@ func Test_Validate_WhenRequiredFieldsMissing_ReturnsError(t *testing.T) {
}
}
// Test_Validate_WhenCloudAndBackupTypeIsNotPgDump_ValidationFails verifies
// that cloud mode rejects any postgres backup type other than PG_DUMP.
func Test_Validate_WhenCloudAndBackupTypeIsNotPgDump_ValidationFails(t *testing.T) {
	enableCloud(t)
	model := &PostgresqlDatabase{
		Host:       "example.com",
		Port:       5432,
		Username:   "user",
		Password:   "pass",
		CpuCount:   1,
		BackupType: PostgresBackupTypeWalV1,
	}
	err := model.Validate()
	assert.EqualError(t, err, "only PG_DUMP backup type is supported in cloud mode")
}
// Test_Validate_WhenCloudAndBackupTypeIsPgDump_ValidationPasses verifies that
// the PG_DUMP backup type remains valid in cloud mode.
func Test_Validate_WhenCloudAndBackupTypeIsPgDump_ValidationPasses(t *testing.T) {
	enableCloud(t)
	model := &PostgresqlDatabase{
		Host:       "example.com",
		Port:       5432,
		Username:   "user",
		Password:   "pass",
		CpuCount:   1,
		BackupType: PostgresBackupTypePgDump,
	}
	err := model.Validate()
	assert.NoError(t, err)
}
// enableCloud switches the process-wide config into cloud mode for the
// duration of the test and restores self-hosted mode on cleanup.
// NOTE(review): mutates shared global config; incompatible with t.Parallel().
func enableCloud(t *testing.T) {
	t.Helper()
	config.GetEnv().IsCloud = true
	t.Cleanup(func() {
		config.GetEnv().IsCloud = false
	})
}
type PostgresContainer struct {
Host string
Port int

View File

@@ -1,20 +0,0 @@
package plans

import (
	"databasus-backend/internal/util/logger"
)

// Package-level singletons for the database plan feature.
var databasePlanRepository = &DatabasePlanRepository{}
var databasePlanService = &DatabasePlanService{
	databasePlanRepository,
	logger.GetLogger(),
}

// GetDatabasePlanService returns the shared DatabasePlanService singleton.
func GetDatabasePlanService() *DatabasePlanService {
	return databasePlanService
}

// GetDatabasePlanRepository returns the shared DatabasePlanRepository singleton.
func GetDatabasePlanRepository() *DatabasePlanRepository {
	return databasePlanRepository
}

View File

@@ -1,19 +0,0 @@
package plans

import (
	"github.com/google/uuid"

	"databasus-backend/internal/util/period"
)

// DatabasePlan stores per-database backup limits. Zero values are used
// elsewhere in this package to represent "unlimited" for self-hosted mode.
type DatabasePlan struct {
	DatabaseID            uuid.UUID         `json:"databaseId" gorm:"column:database_id;type:uuid;primaryKey;not null"`
	MaxBackupSizeMB       int64             `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
	MaxBackupsTotalSizeMB int64             `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
	MaxStoragePeriod      period.TimePeriod `json:"maxStoragePeriod" gorm:"column:max_storage_period;type:text;not null"`
}

// TableName maps the model to the database_plans table.
func (p *DatabasePlan) TableName() string {
	return "database_plans"
}

View File

@@ -1,27 +0,0 @@
package plans

import (
	"github.com/google/uuid"

	"databasus-backend/internal/storage"
)

// DatabasePlanRepository persists per-database plan limits.
type DatabasePlanRepository struct{}

// GetDatabasePlan returns the plan for the database, or nil (without an
// error) when none exists yet.
func (r *DatabasePlanRepository) GetDatabasePlan(databaseID uuid.UUID) (*DatabasePlan, error) {
	var databasePlan DatabasePlan
	if err := storage.GetDb().Where("database_id = ?", databaseID).First(&databasePlan).Error; err != nil {
		// NOTE(review): matching the error by message text is fragile;
		// errors.Is(err, gorm.ErrRecordNotFound) would be safer — confirm
		// gorm can be imported here
		if err.Error() == "record not found" {
			return nil, nil
		}
		return nil, err
	}
	return &databasePlan, nil
}

// CreateDatabasePlan inserts a new plan row.
func (r *DatabasePlanRepository) CreateDatabasePlan(databasePlan *DatabasePlan) error {
	return storage.GetDb().Create(&databasePlan).Error
}

View File

@@ -1,68 +0,0 @@
package plans

import (
	"log/slog"
	"github.com/google/uuid"
	"databasus-backend/internal/config"
	"databasus-backend/internal/util/period"
)

// DatabasePlanService resolves backup limit plans for databases, creating
// a mode-appropriate default (cloud vs. self-hosted) when none exists.
type DatabasePlanService struct {
	databasePlanRepository *DatabasePlanRepository // persistence layer for plans
	logger                 *slog.Logger
}
// GetDatabasePlan returns the plan for the given database, lazily creating
// and persisting a default plan on first access.
func (s *DatabasePlanService) GetDatabasePlan(databaseID uuid.UUID) (*DatabasePlan, error) {
	existing, err := s.databasePlanRepository.GetDatabasePlan(databaseID)
	if err != nil {
		return nil, err
	}
	if existing != nil {
		return existing, nil
	}

	// No row yet: build the mode-dependent default and store it.
	s.logger.Info("no database plan found, creating default plan", "databaseID", databaseID)
	created := s.createDefaultDatabasePlan(databaseID)
	if err := s.databasePlanRepository.CreateDatabasePlan(created); err != nil {
		s.logger.Error("failed to create default database plan", "error", err)
		return nil, err
	}
	return created, nil
}
// createDefaultDatabasePlan builds (without persisting) the default plan
// for a database: tightly limited in cloud mode, unlimited when self-hosted.
func (s *DatabasePlanService) createDefaultDatabasePlan(databaseID uuid.UUID) *DatabasePlan {
	if config.GetEnv().IsCloud {
		s.logger.Info("creating default database plan for cloud", "databaseID", databaseID)
		// for playground we set limited storages enough to test,
		// but not too expensive to provide it for Databasus
		return &DatabasePlan{
			DatabaseID:            databaseID,
			MaxBackupSizeMB:       100,  // ~ 1.5GB database
			MaxBackupsTotalSizeMB: 4000, // ~ 30 daily backups + 10 manual backups
			MaxStoragePeriod:      period.PeriodWeek,
		}
	}

	s.logger.Info("creating default database plan for self hosted", "databaseID", databaseID)
	// by default - everything is unlimited in self hosted mode
	return &DatabasePlan{
		DatabaseID:            databaseID,
		MaxBackupSizeMB:       0,
		MaxBackupsTotalSizeMB: 0,
		MaxStoragePeriod:      period.PeriodForever,
	}
}

View File

@@ -775,7 +775,123 @@ func cleanupDatabaseWithBackup(database *databases.Database, backup *backups_cor
}
}
// Verifies that cloud mode rejects a restore requesting more than one CPU
// with HTTP 400 and an explanatory message.
func Test_RestoreBackup_WhenCloudAndCpuCountMoreThanOne_ReturnsBadRequest(t *testing.T) {
	router := createTestRouter()

	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

	database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
	defer cleanupDatabaseWithBackup(database, backup)

	enableCloud(t)

	payload := restores_core.RestoreBackupRequest{
		PostgresqlDatabase: &postgresql.PostgresqlDatabase{
			Version:  tools.PostgresqlVersion16,
			Host:     env_config.GetEnv().TestLocalhost,
			Port:     5432,
			Username: "postgres",
			Password: "postgres",
			CpuCount: 4, // more than one CPU must be refused in cloud mode
		},
	}

	url := fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String())
	resp := test_utils.MakePostRequest(
		t,
		router,
		url,
		"Bearer "+owner.Token,
		payload,
		http.StatusBadRequest,
	)

	assert.Contains(t, string(resp.Body), "multi-thread restore is not supported in cloud mode")
}
// Verifies that a single-CPU restore is accepted in cloud mode.
func Test_RestoreBackup_WhenCloudAndCpuCountIsOne_RestoreInitiated(t *testing.T) {
	router := createTestRouter()

	_, cleanup := SetupMockRestoreNode(t)
	defer cleanup()

	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

	database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
	defer cleanupDatabaseWithBackup(database, backup)

	enableCloud(t)

	payload := restores_core.RestoreBackupRequest{
		PostgresqlDatabase: &postgresql.PostgresqlDatabase{
			Version:  tools.PostgresqlVersion16,
			Host:     env_config.GetEnv().TestLocalhost,
			Port:     5432,
			Username: "postgres",
			Password: "postgres",
			CpuCount: 1, // single thread is the only allowed mode in cloud
		},
	}

	url := fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String())
	resp := test_utils.MakePostRequest(
		t,
		router,
		url,
		"Bearer "+owner.Token,
		payload,
		http.StatusOK,
	)

	assert.Contains(t, string(resp.Body), "restore started successfully")
}
// Verifies that self-hosted (non-cloud) mode still allows multi-CPU restores.
func Test_RestoreBackup_WhenNotCloudAndCpuCountMoreThanOne_RestoreInitiated(t *testing.T) {
	router := createTestRouter()

	_, cleanup := SetupMockRestoreNode(t)
	defer cleanup()

	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

	database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
	defer cleanupDatabaseWithBackup(database, backup)

	// Note: cloud mode deliberately NOT enabled here.
	payload := restores_core.RestoreBackupRequest{
		PostgresqlDatabase: &postgresql.PostgresqlDatabase{
			Version:  tools.PostgresqlVersion16,
			Host:     env_config.GetEnv().TestLocalhost,
			Port:     5432,
			Username: "postgres",
			Password: "postgres",
			CpuCount: 4,
		},
	}

	url := fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String())
	resp := test_utils.MakePostRequest(
		t,
		router,
		url,
		"Bearer "+owner.Token,
		payload,
		http.StatusOK,
	)

	assert.Contains(t, string(resp.Body), "restore started successfully")
}
// cleanupBackup deletes the given backup row directly via the repository.
func cleanupBackup(backup *backups_core.Backup) {
	(&backups_core.BackupRepository{}).DeleteByID(backup.ID)
}
// enableCloud switches the process into cloud mode for the duration of the
// test. The cleanup restores the PREVIOUS value rather than hardcoding
// false, so the helper is correct even if the suite ever runs with
// IsCloud already set (the old version would flip it off for later tests).
func enableCloud(t *testing.T) {
	t.Helper()
	previous := env_config.GetEnv().IsCloud
	env_config.GetEnv().IsCloud = true
	t.Cleanup(func() {
		env_config.GetEnv().IsCloud = previous
	})
}

View File

@@ -129,11 +129,14 @@ func (s *RestoreService) RestoreBackupWithAuth(
return err
}
if config.GetEnv().IsCloud {
// in cloud mode we use only single thread mode,
// because otherwise we will exhaust local storage
// space (instead of streaming from S3 directly to DB)
requestDTO.PostgresqlDatabase.CpuCount = 1
if config.GetEnv().IsCloud && requestDTO.PostgresqlDatabase != nil &&
requestDTO.PostgresqlDatabase.CpuCount > 1 {
s.logger.Warn("restore rejected: multi-thread mode not supported in cloud",
"requested_cpu_count", requestDTO.PostgresqlDatabase.CpuCount)
return errors.New(
"multi-thread restore is not supported in cloud mode, only single thread (CPU=1) is allowed",
)
}
if err := s.validateVersionCompatibility(backupDatabase, requestDTO); err != nil {

View File

@@ -21,13 +21,19 @@ func NewMultiHandler(
}
// Enabled reports whether a record at the given level should be handled.
// When a VictoriaLogs writer is configured, everything from Debug upward is
// accepted (so VictoriaLogs receives all levels; Handle applies the stdout
// handler's own filter separately). Without one, the stdout handler's level
// filter decides.
func (h *MultiHandler) Enabled(ctx context.Context, level slog.Level) bool {
	if h.victoriaLogsWriter != nil {
		return level >= slog.LevelDebug
	}
	return h.stdoutHandler.Enabled(ctx, level)
}
func (h *MultiHandler) Handle(ctx context.Context, record slog.Record) error {
// Send to stdout handler
if err := h.stdoutHandler.Handle(ctx, record); err != nil {
return err
// Send to stdout handler (only if level is enabled for stdout)
if h.stdoutHandler.Enabled(ctx, record.Level) {
if err := h.stdoutHandler.Handle(ctx, record); err != nil {
return err
}
}
// Send to VictoriaLogs if configured

View File

@@ -0,0 +1,102 @@
-- +goose Up
-- +goose StatementBegin
-- Billing schema: one subscription per database, its invoices, an audit
-- trail of subscription changes, and a raw inbox of provider webhooks.

-- One row per (Paddle) subscription attached to a database.
-- NOTE(review): database_id has no FK to databases(id) — confirm this is
-- intentional (e.g. to retain billing history after a database is deleted).
CREATE TABLE subscriptions (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    database_id UUID NOT NULL,
    status TEXT NOT NULL,
    storage_gb INT NOT NULL,
    pending_storage_gb INT,
    current_period_start TIMESTAMPTZ NOT NULL,
    current_period_end TIMESTAMPTZ NOT NULL,
    canceled_at TIMESTAMPTZ,
    data_retention_grace_period_until TIMESTAMPTZ,
    provider_name TEXT,
    provider_sub_id TEXT,
    provider_customer_id TEXT,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE INDEX idx_subscriptions_database_id ON subscriptions (database_id);
CREATE INDEX idx_subscriptions_status ON subscriptions (status);
CREATE INDEX idx_subscriptions_provider_sub_id ON subscriptions (provider_sub_id);

-- Invoices issued by the payment provider for a subscription.
CREATE TABLE invoices (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    subscription_id UUID NOT NULL,
    provider_invoice_id TEXT NOT NULL,
    amount_cents BIGINT NOT NULL,
    storage_gb INT NOT NULL,
    period_start TIMESTAMPTZ NOT NULL,
    period_end TIMESTAMPTZ NOT NULL,
    status TEXT NOT NULL,
    paid_at TIMESTAMPTZ,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- Deleting a subscription cascades to its invoices.
ALTER TABLE invoices
    ADD CONSTRAINT fk_invoices_subscription_id
    FOREIGN KEY (subscription_id)
    REFERENCES subscriptions (id)
    ON DELETE CASCADE;

CREATE INDEX idx_invoices_subscription_id ON invoices (subscription_id);
CREATE INDEX idx_invoices_provider_invoice_id ON invoices (provider_invoice_id);

-- Audit trail of subscription transitions (storage size / status changes).
CREATE TABLE subscription_events (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    subscription_id UUID NOT NULL,
    provider_event_id TEXT,
    type TEXT NOT NULL,
    old_storage_gb INT,
    new_storage_gb INT,
    old_status TEXT,
    new_status TEXT,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

ALTER TABLE subscription_events
    ADD CONSTRAINT fk_subscription_events_subscription_id
    FOREIGN KEY (subscription_id)
    REFERENCES subscriptions (id)
    ON DELETE CASCADE;

CREATE INDEX idx_subscription_events_subscription_id ON subscription_events (subscription_id);

-- Raw webhook payloads as received from the provider, with processing result.
-- NOTE(review): the provider_event_id index is non-unique; if it is used for
-- webhook idempotency/dedup, consider a UNIQUE index — confirm with the handler.
CREATE TABLE webhook_records (
    request_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    provider_name TEXT NOT NULL,
    event_type TEXT NOT NULL,
    provider_event_id TEXT NOT NULL,
    raw_payload TEXT NOT NULL,
    processed_at TIMESTAMPTZ,
    error TEXT,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE INDEX idx_webhook_records_provider_event_id ON webhook_records (provider_event_id);
-- +goose StatementEnd

-- +goose Down
-- +goose StatementBegin
-- Tear down in reverse dependency order.
DROP INDEX IF EXISTS idx_webhook_records_provider_event_id;
DROP TABLE IF EXISTS webhook_records;
DROP INDEX IF EXISTS idx_subscription_events_subscription_id;
ALTER TABLE subscription_events DROP CONSTRAINT IF EXISTS fk_subscription_events_subscription_id;
DROP TABLE IF EXISTS subscription_events;
DROP INDEX IF EXISTS idx_invoices_provider_invoice_id;
DROP INDEX IF EXISTS idx_invoices_subscription_id;
ALTER TABLE invoices DROP CONSTRAINT IF EXISTS fk_invoices_subscription_id;
DROP TABLE IF EXISTS invoices;
DROP INDEX IF EXISTS idx_subscriptions_provider_sub_id;
DROP INDEX IF EXISTS idx_subscriptions_status;
DROP INDEX IF EXISTS idx_subscriptions_database_id;
DROP TABLE IF EXISTS subscriptions;
-- +goose StatementEnd

View File

@@ -0,0 +1,39 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE backup_configs
DROP COLUMN IF EXISTS max_backup_size_mb,
DROP COLUMN IF EXISTS max_backups_total_size_mb;
DROP INDEX IF EXISTS idx_database_plans_database_id;
ALTER TABLE database_plans
DROP CONSTRAINT IF EXISTS fk_database_plans_database_id;
DROP TABLE IF EXISTS database_plans;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE backup_configs
ADD COLUMN max_backup_size_mb BIGINT NOT NULL DEFAULT 0,
ADD COLUMN max_backups_total_size_mb BIGINT NOT NULL DEFAULT 0;
CREATE TABLE database_plans (
database_id UUID PRIMARY KEY,
max_backup_size_mb BIGINT NOT NULL,
max_backups_total_size_mb BIGINT NOT NULL,
max_storage_period TEXT NOT NULL
);
ALTER TABLE database_plans
ADD CONSTRAINT fk_database_plans_database_id
FOREIGN KEY (database_id)
REFERENCES databases (id)
ON DELETE CASCADE;
CREATE INDEX idx_database_plans_database_id ON database_plans (database_id);
-- +goose StatementEnd

View File

@@ -0,0 +1,11 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE webhook_records
ADD COLUMN is_skipped BOOLEAN NOT NULL DEFAULT FALSE;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE webhook_records
DROP COLUMN is_skipped;
-- +goose StatementEnd