FEATURE (cloud): Add cloud

This commit is contained in:
Rostislav Dugin
2026-03-26 12:35:32 +03:00
parent c648e9c29f
commit 61a0bcabb1
106 changed files with 8924 additions and 1963 deletions

609
AGENTS.md
View File

@@ -9,7 +9,6 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
- [Engineering philosophy](#engineering-philosophy)
- [Backend guidelines](#backend-guidelines)
- [Code style](#code-style)
- [Boolean naming](#boolean-naming)
- [Add reasonable new lines between logical statements](#add-reasonable-new-lines-between-logical-statements)
- [Comments](#comments)
@@ -19,6 +18,7 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
- [Refactoring](#refactoring)
- [Testing](#testing)
- [Time handling](#time-handling)
- [Logging](#logging)
- [CRUD examples](#crud-examples)
- [Frontend guidelines](#frontend-guidelines)
- [React component structure](#react-component-structure)
@@ -122,157 +122,6 @@ Good:
Exception: widely used variables like "db", "ctx", "req", "res", etc.
### Code style
**Always place private methods to the bottom of file**
This rule applies to ALL Go files including tests, services, controllers, repositories, etc.
In Go, exported (public) functions/methods start with uppercase letters, while unexported (private) ones start with lowercase letters.
#### Structure order:
1. Type definitions and constants
2. Public methods/functions (uppercase)
3. Private methods/functions (lowercase)
#### Examples:
**Service with methods:**
```go
type UserService struct {
repository *UserRepository
}
// Public methods first
func (s *UserService) CreateUser(user *User) error {
if err := s.validateUser(user); err != nil {
return err
}
return s.repository.Save(user)
}
func (s *UserService) GetUser(id uuid.UUID) (*User, error) {
return s.repository.FindByID(id)
}
// Private methods at the bottom
func (s *UserService) validateUser(user *User) error {
if user.Name == "" {
return errors.New("name is required")
}
return nil
}
```
**Package-level functions:**
```go
package utils
// Public functions first
func ProcessData(data []byte) (Result, error) {
cleaned := sanitizeInput(data)
return parseData(cleaned)
}
func ValidateInput(input string) bool {
return isValidFormat(input) && checkLength(input)
}
// Private functions at the bottom
func sanitizeInput(data []byte) []byte {
// implementation
}
func parseData(data []byte) (Result, error) {
// implementation
}
func isValidFormat(input string) bool {
// implementation
}
func checkLength(input string) bool {
// implementation
}
```
**Test files:**
```go
package user_test
// Public test functions first
func Test_CreateUser_ValidInput_UserCreated(t *testing.T) {
user := createTestUser()
result, err := service.CreateUser(user)
assert.NoError(t, err)
assert.NotNil(t, result)
}
func Test_GetUser_ExistingUser_ReturnsUser(t *testing.T) {
user := createTestUser()
// test implementation
}
// Private helper functions at the bottom
func createTestUser() *User {
return &User{
Name: "Test User",
Email: "test@example.com",
}
}
func setupTestDatabase() *Database {
// setup implementation
}
```
**Controller example:**
```go
type ProjectController struct {
service *ProjectService
}
// Public HTTP handlers first
func (c *ProjectController) CreateProject(ctx *gin.Context) {
var request CreateProjectRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
c.handleError(ctx, err)
return
}
// handler logic
}
func (c *ProjectController) GetProject(ctx *gin.Context) {
projectID := c.extractProjectID(ctx)
// handler logic
}
// Private helper methods at the bottom
func (c *ProjectController) handleError(ctx *gin.Context, err error) {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}
func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
return uuid.MustParse(ctx.Param("projectId"))
}
```
#### Key points:
- **Exported/Public** = starts with uppercase letter (CreateUser, GetProject)
- **Unexported/Private** = starts with lowercase letter (validateUser, handleError)
- This improves code readability by showing the public API first
- Private helpers are implementation details, so they go at the bottom
- Apply this rule consistently across ALL Go files in the project
---
### Boolean naming
**Always prefix boolean variables with verbs like `is`, `has`, `was`, `should`, `can`, etc.**
@@ -1001,6 +850,20 @@ func Test_BackupLifecycle_CreateAndDelete(t *testing.T) {
}
```
#### Cloud testing
If you are testing cloud, set isCloud = true before test run and defer isCloud = false after test run. Example helper function:
```go
func enableCloud(t *testing.T) {
t.Helper()
config.GetEnv().IsCloud = true
t.Cleanup(func() {
config.GetEnv().IsCloud = false
})
}
```
#### Testing utilities structure
**Create `testing.go` or `testing/testing.go` files with common utilities:**
@@ -1112,6 +975,100 @@ This ensures consistent timezone handling across the application.
---
### Logging
We use `log/slog` for structured logging. Follow these conventions to keep logs consistent, searchable, and useful for debugging.
#### Scoped loggers for tracing
Attach IDs via `logger.With(...)` as early as possible so every downstream log line carries them automatically. Common IDs: `database_id`, `subscription_id`, `backup_id`, `storage_id`, `user_id`.
```go
func (s *BillingService) CreateSubscription(logger *slog.Logger, user *users_models.User, databaseID uuid.UUID, storageGB int) {
logger = logger.With("database_id", databaseID)
// all subsequent logger calls automatically include database_id
logger.Debug(fmt.Sprintf("creating subscription for storage %d GB", storageGB))
}
```
For background services, create scoped loggers with `task_name` for each subtask in `Run()`:
```go
func (c *BackupCleaner) Run(ctx context.Context) {
retentionLog := c.logger.With("task_name", "clean_by_retention_policy")
exceededLog := c.logger.With("task_name", "clean_exceeded_backups")
// pass scoped logger to each method
c.cleanByRetentionPolicy(retentionLog)
c.cleanExceededBackups(exceededLog)
}
```
Within loops, scope further:
```go
for _, backupConfig := range enabledBackupConfigs {
dbLog := logger.With("database_id", backupConfig.DatabaseID, "policy", backupConfig.RetentionPolicyType)
// ...
}
```
#### Values in message, IDs as kv pairs
**Values and statuses** (sizes, counts, status transitions) go into the message via `fmt.Sprintf`:
```go
logger.Info(fmt.Sprintf("subscription renewed: %s -> %s, %d GB", oldStatus, newStatus, sub.StorageGB))
logger.Info(
fmt.Sprintf("deleted exceeded backup: backup size is %.1f MB, total size is %.1f MB, limit is %d MB",
backup.BackupSizeMb, backupsTotalSizeMB, limitPerDbMB),
"backup_id", backup.ID,
)
```
**IDs** stay as structured kv pairs — never inline them into the message string. This keeps them searchable in log aggregation tools:
```go
// good
logger.Info("deleted old backup", "backup_id", backup.ID)
// bad — ID buried in message, not searchable
logger.Info(fmt.Sprintf("deleted old backup %s", backup.ID))
```
**`error` is always a kv pair**, never inlined into the message:
```go
// good
logger.Error("failed to save subscription", "error", err)
// bad
logger.Error(fmt.Sprintf("failed to save subscription: %v", err))
```
#### Key naming and message style
- **snake_case for all log keys**: `database_id`, `backup_id`, `task_name`, `total_size_mb` — not camelCase
- **Lowercase log messages**: start with lowercase, no trailing period
```go
// good
logger.Error("failed to create checkout session", "error", err)
// bad
logger.Error("Failed to create checkout session.", "error", err)
```
#### Log level usage
- **Debug**: routine operations, entering a function, query results count (`"getting subscription events"`, `"found 5 invoices"`)
- **Info**: significant state changes, completed actions (`"subscription activated"`, `"deleted exceeded backup"`)
- **Warn**: degraded but recoverable situations (`"oldest backup is too recent to delete"`, `"requested storage is the same as current"`)
- **Error**: failures that need attention (`"failed to save subscription"`, `"failed to delete backup file"`)
---
### CRUD examples
This is an example of complete CRUD implementation structure:
@@ -1127,7 +1084,6 @@ import (
user_models "databasus-backend/internal/features/users/models"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type AuditLogController struct {
@@ -1135,7 +1091,6 @@ type AuditLogController struct {
}
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
// All audit log endpoints require authentication (handled in main.go)
auditRoutes := router.Group("/audit-logs")
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
@@ -1151,7 +1106,6 @@ func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
// @Security BearerAuth
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
@@ -1182,54 +1136,7 @@ func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
ctx.JSON(http.StatusOK, response)
}
// GetUserAuditLogs
// @Summary Get user audit logs
// @Description Retrieve audit logs for a specific user
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param userId path string true "User ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/users/{userId} [get]
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
userIDStr := ctx.Param("userId")
targetUserID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
if err != nil {
if err.Error() == "insufficient permissions to view user audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}
// GetUserAuditLogs follows the same pattern...
```
#### controller_test.go
@@ -1237,34 +1144,13 @@ func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
```go
package audit_logs
import (
"fmt"
"net/http"
"testing"
"time"
user_enums "databasus-backend/internal/features/users/enums"
users_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
users_testing "databasus-backend/internal/features/users/testing"
"databasus-backend/internal/storage"
test_utils "databasus-backend/internal/util/testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_GetGlobalAuditLogs_AdminSucceedsAndMemberGetsForbidden(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
memberUser := users_testing.CreateTestUser(user_enums.UserRoleMember)
router := createRouter()
service := GetAuditLogService()
projectID := uuid.New()
// Create test logs
createAuditLog(service, "Test log with user", &adminUser.UserID, nil)
createAuditLog(service, "Test log with project", nil, &projectID)
createAuditLog(service, "Test log standalone", nil, nil)
// Test ADMIN can access global logs
@@ -1272,13 +1158,9 @@ func Test_GetGlobalAuditLogs_AdminSucceedsAndMemberGetsForbidden(t *testing.T) {
test_utils.MakeGetRequestAndUnmarshal(t, router,
"/api/v1/audit-logs/global?limit=10", "Bearer "+adminUser.Token, http.StatusOK, &response)
assert.GreaterOrEqual(t, len(response.AuditLogs), 3)
assert.GreaterOrEqual(t, response.Total, int64(3))
assert.GreaterOrEqual(t, len(response.AuditLogs), 2)
messages := extractMessages(response.AuditLogs)
assert.Contains(t, messages, "Test log with user")
assert.Contains(t, messages, "Test log with project")
assert.Contains(t, messages, "Test log standalone")
// Test MEMBER cannot access global logs
resp := test_utils.MakeGetRequest(t, router, "/api/v1/audit-logs/global",
@@ -1286,79 +1168,6 @@ func Test_GetGlobalAuditLogs_AdminSucceedsAndMemberGetsForbidden(t *testing.T) {
assert.Contains(t, string(resp.Body), "only administrators can view global audit logs")
}
func Test_GetUserAuditLogs_PermissionsEnforcedCorrectly(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
router := createRouter()
service := GetAuditLogService()
projectID := uuid.New()
// Create test logs for different users
createAuditLog(service, "Test log user1 first", &user1.UserID, nil)
createAuditLog(service, "Test log user1 second", &user1.UserID, &projectID)
createAuditLog(service, "Test log user2 first", &user2.UserID, nil)
createAuditLog(service, "Test log user2 second", &user2.UserID, &projectID)
createAuditLog(service, "Test project log", nil, &projectID)
// Test ADMIN can view any user's logs
var user1Response GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=10", user1.UserID.String()),
"Bearer "+adminUser.Token, http.StatusOK, &user1Response)
assert.Equal(t, 2, len(user1Response.AuditLogs))
messages := extractMessages(user1Response.AuditLogs)
assert.Contains(t, messages, "Test log user1 first")
assert.Contains(t, messages, "Test log user1 second")
// Test user can view own logs
var ownLogsResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s", user2.UserID.String()),
"Bearer "+user2.Token, http.StatusOK, &ownLogsResponse)
assert.Equal(t, 2, len(ownLogsResponse.AuditLogs))
// Test user cannot view other user's logs
resp := test_utils.MakeGetRequest(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s", user1.UserID.String()),
"Bearer "+user2.Token, http.StatusForbidden)
assert.Contains(t, string(resp.Body), "insufficient permissions")
}
func Test_FilterAuditLogsByTime_ReturnsOnlyLogsBeforeDate(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
router := createRouter()
service := GetAuditLogService()
db := storage.GetDb()
baseTime := time.Now().UTC()
// Create logs with different timestamps
createTimedLog(db, &adminUser.UserID, "Test old log", baseTime.Add(-2*time.Hour))
createTimedLog(db, &adminUser.UserID, "Test recent log", baseTime.Add(-30*time.Minute))
createAuditLog(service, "Test current log", &adminUser.UserID, nil)
// Test filtering - get logs before 1 hour ago
beforeTime := baseTime.Add(-1 * time.Hour)
var filteredResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/global?beforeDate=%s", beforeTime.Format(time.RFC3339)),
"Bearer "+adminUser.Token, http.StatusOK, &filteredResponse)
// Verify only old log is returned
messages := extractMessages(filteredResponse.AuditLogs)
assert.Contains(t, messages, "Test old log")
assert.NotContains(t, messages, "Test recent log")
assert.NotContains(t, messages, "Test current log")
// Test without filter - should get all logs
var allResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router, "/api/v1/audit-logs/global",
"Bearer "+adminUser.Token, http.StatusOK, &allResponse)
assert.GreaterOrEqual(t, len(allResponse.AuditLogs), 3)
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
@@ -1384,12 +1193,10 @@ import (
var auditLogRepository = &AuditLogRepository{}
var auditLogService = &AuditLogService{
auditLogRepository: auditLogRepository,
logger: logger.GetLogger(),
}
var auditLogController = &AuditLogController{
auditLogService: auditLogService,
auditLogRepository,
logger.GetLogger(),
}
var auditLogController = &AuditLogController{auditLogService}
func GetAuditLogService() *AuditLogService {
return auditLogService
@@ -1427,7 +1234,7 @@ type GetAuditLogsResponse struct {
}
```
#### models.go
#### model.go
```go
package audit_logs
@@ -1490,63 +1297,7 @@ func (r *AuditLogRepository) GetGlobal(limit, offset int, beforeDate *time.Time)
return auditLogs, err
}
func (r *AuditLogRepository) GetByUser(
userID uuid.UUID,
limit, offset int,
beforeDate *time.Time,
) ([]*AuditLog, error) {
var auditLogs []*AuditLog
query := storage.GetDb().
Where("user_id = ?", userID).
Order("created_at DESC")
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.
Limit(limit).
Offset(offset).
Find(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) GetByProject(
projectID uuid.UUID,
limit, offset int,
beforeDate *time.Time,
) ([]*AuditLog, error) {
var auditLogs []*AuditLog
query := storage.GetDb().
Where("project_id = ?", projectID).
Order("created_at DESC")
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.
Limit(limit).
Offset(offset).
Find(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
var count int64
query := storage.GetDb().Model(&AuditLog{})
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.Count(&count).Error
return count, err
}
// GetByUser, GetByProject, CountGlobal follow the same pattern...
```
#### service.go
@@ -1570,11 +1321,7 @@ type AuditLogService struct {
logger *slog.Logger
}
func (s *AuditLogService) WriteAuditLog(
message string,
userID *uuid.UUID,
projectID *uuid.UUID,
) {
func (s *AuditLogService) WriteAuditLog(message string, userID *uuid.UUID, projectID *uuid.UUID) {
auditLog := &AuditLog{
UserID: userID,
ProjectID: projectID,
@@ -1582,17 +1329,11 @@ func (s *AuditLogService) WriteAuditLog(
CreatedAt: time.Now().UTC(),
}
err := s.auditLogRepository.Create(auditLog)
if err != nil {
if err := s.auditLogRepository.Create(auditLog); err != nil {
s.logger.Error("failed to create audit log", "error", err)
return
}
}
func (s *AuditLogService) CreateAuditLog(auditLog *AuditLog) error {
return s.auditLogRepository.Create(auditLog)
}
func (s *AuditLogService) GetGlobalAuditLogs(
user *user_models.User,
request *GetAuditLogsRequest,
@@ -1626,59 +1367,7 @@ func (s *AuditLogService) GetGlobalAuditLogs(
}, nil
}
func (s *AuditLogService) GetUserAuditLogs(
targetUserID uuid.UUID,
user *user_models.User,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
// Users can view their own logs, ADMIN can view any user's logs
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
return nil, errors.New("insufficient permissions to view user audit logs")
}
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetByUser(targetUserID, limit, offset, request.BeforeDate)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: int64(len(auditLogs)),
Limit: limit,
Offset: offset,
}, nil
}
func (s *AuditLogService) GetProjectAuditLogs(
projectID uuid.UUID,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetByProject(projectID, limit, offset, request.BeforeDate)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: int64(len(auditLogs)),
Limit: limit,
Offset: offset,
}, nil
}
// GetUserAuditLogs, GetProjectAuditLogs follow the same pattern...
```
#### service_test.go
@@ -1686,34 +1375,16 @@ func (s *AuditLogService) GetProjectAuditLogs(
```go
package audit_logs
import (
"testing"
"time"
user_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
)
func Test_AuditLogs_ProjectSpecificLogs(t *testing.T) {
service := GetAuditLogService()
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
project1ID, project2ID := uuid.New(), uuid.New()
project1ID := uuid.New()
// Create test logs for projects
createAuditLog(service, "Test project1 log first", &user1.UserID, &project1ID)
createAuditLog(service, "Test project1 log second", &user2.UserID, &project1ID)
createAuditLog(service, "Test project2 log first", &user1.UserID, &project2ID)
createAuditLog(service, "Test project2 log second", &user2.UserID, &project2ID)
createAuditLog(service, "Test no project log", &user1.UserID, nil)
createAuditLog(service, "Test project1 log second", &user1.UserID, &project1ID)
request := &GetAuditLogsRequest{Limit: 10, Offset: 0}
// Test project 1 logs
project1Response, err := service.GetProjectAuditLogs(project1ID, request)
assert.NoError(t, err)
assert.Equal(t, 2, len(project1Response.AuditLogs))
@@ -1721,34 +1392,6 @@ func Test_AuditLogs_ProjectSpecificLogs(t *testing.T) {
messages := extractMessages(project1Response.AuditLogs)
assert.Contains(t, messages, "Test project1 log first")
assert.Contains(t, messages, "Test project1 log second")
for _, log := range project1Response.AuditLogs {
assert.Equal(t, &project1ID, log.ProjectID)
}
// Test project 2 logs
project2Response, err := service.GetProjectAuditLogs(project2ID, request)
assert.NoError(t, err)
assert.Equal(t, 2, len(project2Response.AuditLogs))
messages2 := extractMessages(project2Response.AuditLogs)
assert.Contains(t, messages2, "Test project2 log first")
assert.Contains(t, messages2, "Test project2 log second")
// Test pagination
limitedResponse, err := service.GetProjectAuditLogs(project1ID,
&GetAuditLogsRequest{Limit: 1, Offset: 0})
assert.NoError(t, err)
assert.Equal(t, 1, len(limitedResponse.AuditLogs))
assert.Equal(t, 1, limitedResponse.Limit)
// Test beforeDate filter
beforeTime := time.Now().UTC().Add(-1 * time.Minute)
filteredResponse, err := service.GetProjectAuditLogs(project1ID,
&GetAuditLogsRequest{Limit: 10, BeforeDate: &beforeTime})
assert.NoError(t, err)
for _, log := range filteredResponse.AuditLogs {
assert.True(t, log.CreatedAt.Before(beforeTime))
}
}
func createAuditLog(service *AuditLogService, message string, userID, projectID *uuid.UUID) {
@@ -1762,16 +1405,6 @@ func extractMessages(logs []*AuditLog) []string {
}
return messages
}
func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt time.Time) {
log := &AuditLog{
ID: uuid.New(),
UserID: userID,
Message: message,
CreatedAt: createdAt,
}
db.Create(log)
}
```
---

View File

@@ -316,7 +316,9 @@ window.__RUNTIME_CONFIG__ = {
GOOGLE_CLIENT_ID: '\${GOOGLE_CLIENT_ID:-}',
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED',
CLOUDFLARE_TURNSTILE_SITE_KEY: '\${CLOUDFLARE_TURNSTILE_SITE_KEY:-}',
CONTAINER_ARCH: '\${CONTAINER_ARCH:-unknown}'
CONTAINER_ARCH: '\${CONTAINER_ARCH:-unknown}',
CLOUD_PRICE_PER_GB: '\${CLOUD_PRICE_PER_GB:-}',
CLOUD_PADDLE_CLIENT_TOKEN: '\${CLOUD_PADDLE_CLIENT_TOKEN:-}'
};
JSEOF
@@ -329,6 +331,14 @@ if [ -n "\${ANALYTICS_SCRIPT:-}" ]; then
fi
fi
# Inject Paddle script if client token is provided (only if not already injected)
if [ -n "\${CLOUD_PADDLE_CLIENT_TOKEN:-}" ]; then
if ! grep -q "cdn.paddle.com" /app/ui/build/index.html 2>/dev/null; then
echo "Injecting Paddle script..."
sed -i "s#</head># <script src=\"https://cdn.paddle.com/paddle/v2/paddle.js\"></script>\n </head>#" /app/ui/build/index.html
fi
fi
# Inject static HTML into root div for cloud mode (payment system requires visible legal links)
if [ "\${IS_CLOUD:-false}" = "true" ]; then
if ! grep -q "cloud-static-content" /app/ui/build/index.html 2>/dev/null; then

22
NOTICE.md Normal file
View File

@@ -0,0 +1,22 @@
Copyright © 2025–2026 Rostislav Dugin and contributors.
“Databasus” is a trademark of Rostislav Dugin.
The source code in this repository is licensed under the Apache License, Version 2.0.
That license applies to the code only and does not grant any right to use the
Databasus name, logo, or branding, except for reasonable and customary referential
use in describing the origin of the software and reproducing the content of this NOTICE.
Permitted referential use includes truthful use of the name “Databasus” to identify
the original Databasus project in software catalogs, deployment templates, hosting
panels, package indexes, compatibility pages, integrations, tutorials, reviews, and
similar informational materials, including phrases such as “Databasus”,
“Deploy Databasus”, “Databasus on Coolify”, and “Compatible with Databasus”.
You may not use “Databasus” as the name or primary branding of a competing product,
service, fork, distribution, or hosted offering, or in any manner likely to cause
confusion as to source, affiliation, sponsorship, or endorsement.
Nothing in this repository transfers, waives, limits, or estops any rights in the
Databasus mark. All trademark rights are reserved except for the limited referential
use stated above.

View File

@@ -1,8 +1,8 @@
<div align="center">
<img src="assets/logo.svg" alt="Databasus Logo" width="250"/>
<h3>Backup tool for PostgreSQL, MySQL and MongoDB</h3>
<p>Databasus is a free, open source and self-hosted tool to backup databases (with focus on PostgreSQL). Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.)</p>
<h3>PostgreSQL backup tool (with MySQL/MariaDB and MongoDB support)</h3>
<p>Databasus is a free, open source and self-hosted tool to backup databases (with primary focus on PostgreSQL). Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.)</p>
<!-- Badges -->
[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-336791?logo=postgresql&logoColor=white)](https://www.postgresql.org/)
@@ -259,7 +259,9 @@ Contributions are welcome! Read the <a href="https://databasus.com/contribute">c
Also you can join our large community of developers, DBAs and DevOps engineers on Telegram [@databasus_community](https://t.me/databasus_community).
## AI disclaimer
## FAQ
### AI disclaimer
There have been questions about AI usage in project development in issues and discussions. As the project focuses on security, reliability and production usage, it's important to explain how AI is used in the development process.
@@ -295,3 +297,11 @@ Moreover, it's important to note that we do not differentiate between bad human
Even if code is written manually by a human, it's not guaranteed to be merged. Vibe code is not allowed at all and all such PRs are rejected by default (see [contributing guide](https://databasus.com/contribute)).
We also draw attention to fast issue resolution and security [vulnerability reporting](https://github.com/databasus/databasus?tab=security-ov-file#readme).
### You have a cloud version — are you truly open source?
Yes. Every feature available in Databasus Cloud is equally available in the self-hosted version with no restrictions, no feature gates and no usage limits. The entire codebase is Apache 2.0 licensed and always will be.
Databasus is not "open core." We do not withhold features behind a paid tier and then call the limited remainder "open source," as projects like GitLab or Sentry do. We believe open source means the complete product is open, not just a marketing label on a stripped-down edition.
Databasus Cloud runs the exact same code as the self-hosted version. The only difference is that we take care of infrastructure, high availability, backups, reservations, monitoring and updates for you — so you don't have to. Revenue from Cloud funds full-time development of the project. Most large open-source projects rely on corporate backing or sponsorship to survive. We chose a different path: Databasus sustains itself so it can grow and improve independently, without being tied to any enterprise or sponsor.

View File

@@ -27,6 +27,13 @@ VALKEY_PORT=6379
VALKEY_USERNAME=
VALKEY_PASSWORD=
VALKEY_IS_SSL=false
# billing
PRICE_PER_GB_CENTS=
IS_PADDLE_SANDBOX=true
PADDLE_API_KEY=
PADDLE_WEBHOOK_SECRET=
PADDLE_PRICE_ID=
PADDLE_CLIENT_TOKEN=
# testing
# to get Google Drive env variables: add storage in UI and copy data from added storage here
TEST_GOOGLE_DRIVE_CLIENT_ID=

View File

@@ -9,6 +9,7 @@ import (
"os/exec"
"os/signal"
"path/filepath"
"runtime/debug"
"syscall"
"time"
@@ -25,6 +26,8 @@ import (
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/billing"
billing_paddle "databasus-backend/internal/features/billing/paddle"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
"databasus-backend/internal/features/encryption/secrets"
@@ -105,7 +108,9 @@ func main() {
go generateSwaggerDocs(log)
gin.SetMode(gin.ReleaseMode)
ginApp := gin.Default()
ginApp := gin.New()
ginApp.Use(gin.Logger())
ginApp.Use(ginRecoveryWithLogger(log))
// Add GZIP compression middleware
ginApp.Use(gzip.Gzip(
@@ -217,6 +222,10 @@ func setUpRoutes(r *gin.Engine) {
backups_controllers.GetPostgresWalBackupController().RegisterRoutes(v1)
databases.GetDatabaseController().RegisterPublicRoutes(v1)
if config.GetEnv().IsCloud {
billing_paddle.GetPaddleBillingController().RegisterPublicRoutes(v1)
}
// Setup auth middleware
userService := users_services.GetUserService()
authMiddleware := users_middleware.AuthMiddleware(userService)
@@ -240,6 +249,7 @@ func setUpRoutes(r *gin.Engine) {
audit_logs.GetAuditLogController().RegisterRoutes(protected)
users_controllers.GetManagementController().RegisterRoutes(protected)
users_controllers.GetSettingsController().RegisterRoutes(protected)
billing.GetBillingController().RegisterRoutes(protected)
}
func setUpDependencies() {
@@ -252,6 +262,11 @@ func setUpDependencies() {
storages.SetupDependencies()
backups_config.SetupDependencies()
task_cancellation.SetupDependencies()
billing.SetupDependencies()
if config.GetEnv().IsCloud {
billing_paddle.SetupDependencies()
}
}
func runBackgroundTasks(log *slog.Logger) {
@@ -308,6 +323,12 @@ func runBackgroundTasks(log *slog.Logger) {
go runWithPanicLogging(log, "restore nodes registry background service", func() {
restoring.GetRestoreNodesRegistry().Run(ctx)
})
if config.GetEnv().IsCloud {
go runWithPanicLogging(log, "billing background service", func() {
billing.GetBillingService().Run(ctx, *log)
})
}
} else {
log.Info("Skipping primary node tasks as not primary node")
}
@@ -330,7 +351,7 @@ func runBackgroundTasks(log *slog.Logger) {
func runWithPanicLogging(log *slog.Logger, serviceName string, fn func()) {
defer func() {
if r := recover(); r != nil {
log.Error("Panic in "+serviceName, "error", r)
log.Error("Panic in "+serviceName, "error", r, "stacktrace", string(debug.Stack()))
}
}()
fn()
@@ -410,6 +431,25 @@ func enableCors(ginApp *gin.Engine) {
}
}
func ginRecoveryWithLogger(log *slog.Logger) gin.HandlerFunc {
return func(ctx *gin.Context) {
defer func() {
if r := recover(); r != nil {
log.Error("Panic recovered in HTTP handler",
"error", r,
"stacktrace", string(debug.Stack()),
"method", ctx.Request.Method,
"path", ctx.Request.URL.Path,
)
ctx.AbortWithStatus(http.StatusInternalServerError)
}
}()
ctx.Next()
}
}
func mountFrontend(ginApp *gin.Engine) {
staticDir := "./ui/build"
ginApp.NoRoute(func(c *gin.Context) {

View File

@@ -5,6 +5,7 @@ go 1.26.1
require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3
github.com/PaddleHQ/paddle-go-sdk v1.0.0
github.com/gin-contrib/cors v1.7.5
github.com/gin-contrib/gzip v1.2.3
github.com/gin-gonic/gin v1.10.0
@@ -100,6 +101,8 @@ require (
github.com/emersion/go-message v0.18.2 // indirect
github.com/emersion/go-vcard v0.0.0-20241024213814-c9703dde27ff // indirect
github.com/flynn/noise v1.1.0 // indirect
github.com/ggicci/httpin v0.19.0 // indirect
github.com/ggicci/owl v0.8.2 // indirect
github.com/go-chi/chi/v5 v5.2.3 // indirect
github.com/go-darwin/apfs v0.0.0-20211011131704-f84b94dbf348 // indirect
github.com/go-git/go-billy/v5 v5.6.2 // indirect

View File

@@ -77,6 +77,8 @@ github.com/Max-Sum/base32768 v0.0.0-20230304063302-18e6ce5945fd/go.mod h1:C8yoIf
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/PaddleHQ/paddle-go-sdk v1.0.0 h1:+EXitsPFbRcc0CpQE/MIeudxiVOR8pFe/aOWTEUHDKU=
github.com/PaddleHQ/paddle-go-sdk v1.0.0/go.mod h1:kbBBzf0BHEj38QvhtoELqlGip3alKgA/I+vl7RQzB58=
github.com/ProtonMail/bcrypt v0.0.0-20210511135022-227b4adcab57/go.mod h1:HecWFHognK8GfRDGnFQbW/LiV7A3MX3gZVs45vk5h8I=
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs69zUkSzubzjBbL+cmOXgnmt9Fyd9ug=
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo=
@@ -248,6 +250,10 @@ github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9t
github.com/geoffgarside/ber v1.1.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc=
github.com/geoffgarside/ber v1.2.0 h1:/loowoRcs/MWLYmGX9QtIAbA+V/FrnVLsMMPhwiRm64=
github.com/geoffgarside/ber v1.2.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc=
github.com/ggicci/httpin v0.19.0 h1:p0B3SWLVgg770VirYiHB14M5wdRx3zR8mCTzM/TkTQ8=
github.com/ggicci/httpin v0.19.0/go.mod h1:hzsQHcbqLabmGOycf7WNw6AAzcVbsMeoOp46bWAbIWc=
github.com/ggicci/owl v0.8.2 h1:og+lhqpzSMPDdEB+NJfzoAJARP7qCG3f8uUC3xvGukA=
github.com/ggicci/owl v0.8.2/go.mod h1:PHRD57u41vFN5UtFz2SF79yTVoM3HlWpjMiE+ZU2dj4=
github.com/gin-contrib/cors v1.7.5 h1:cXC9SmofOrRg0w9PigwGlHG3ztswH6bqq4vJVXnvYMk=
github.com/gin-contrib/cors v1.7.5/go.mod h1:4q3yi7xBEDDWKapjT2o1V7mScKDDr8k+jZ0fSquGoy0=
github.com/gin-contrib/gzip v1.2.3 h1:dAhT722RuEG330ce2agAs75z7yB+NKvX/ZM1r8w0u2U=
@@ -454,6 +460,8 @@ github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
github.com/jtolio/noiseconn v0.0.0-20231127013910-f6d9ecbf1de7 h1:JcltaO1HXM5S2KYOYcKgAV7slU0xPy1OcvrVgn98sRQ=
github.com/jtolio/noiseconn v0.0.0-20231127013910-f6d9ecbf1de7/go.mod h1:MEkhEPFwP3yudWO0lj6vfYpLIB+3eIcuIW+e0AZzUQk=
github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo=
github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA=
github.com/jzelinskie/whirlpool v0.0.0-20201016144138-0675e54bb004 h1:G+9t9cEtnC9jFiTxyptEKuNIAbiN5ZCQzX2a74lj3xg=
github.com/jzelinskie/whirlpool v0.0.0-20201016144138-0675e54bb004/go.mod h1:KmHnJWQrgEvbuy0vcvj00gtMqbvNn1L+3YUZLK/B92c=
github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU=

View File

@@ -5,6 +5,7 @@ import (
"path/filepath"
"strings"
"sync"
"time"
"github.com/ilyakaznacheev/cleanenv"
"github.com/joho/godotenv"
@@ -53,6 +54,20 @@ type EnvVariables struct {
TempFolder string
SecretKeyPath string
// Billing (always tax-exclusive)
PricePerGBCents int64 `env:"PRICE_PER_GB_CENTS"`
MinStorageGB int
MaxStorageGB int
TrialDuration time.Duration
TrialStorageGB int
GracePeriod time.Duration
// Paddle billing
IsPaddleSandbox bool `env:"IS_PADDLE_SANDBOX"`
PaddleApiKey string `env:"PADDLE_API_KEY"`
PaddleWebhookSecret string `env:"PADDLE_WEBHOOK_SECRET"`
PaddlePriceID string `env:"PADDLE_PRICE_ID"`
PaddleClientToken string `env:"PADDLE_CLIENT_TOKEN"`
TestGoogleDriveClientID string `env:"TEST_GOOGLE_DRIVE_CLIENT_ID"`
TestGoogleDriveClientSecret string `env:"TEST_GOOGLE_DRIVE_CLIENT_SECRET"`
TestGoogleDriveTokenJSON string `env:"TEST_GOOGLE_DRIVE_TOKEN_JSON"`
@@ -132,9 +147,9 @@ var (
once sync.Once
)
func GetEnv() EnvVariables {
func GetEnv() *EnvVariables {
once.Do(loadEnvVariables)
return env
return &env
}
func loadEnvVariables() {
@@ -363,5 +378,39 @@ func loadEnvVariables() {
}
// Billing
if env.IsCloud {
if env.PricePerGBCents == 0 {
log.Error("PRICE_PER_GB_CENTS is empty or zero")
os.Exit(1)
}
if env.PaddleApiKey == "" {
log.Error("PADDLE_API_KEY is empty")
os.Exit(1)
}
if env.PaddleWebhookSecret == "" {
log.Error("PADDLE_WEBHOOK_SECRET is empty")
os.Exit(1)
}
if env.PaddlePriceID == "" {
log.Error("PADDLE_PRICE_ID is empty")
os.Exit(1)
}
if env.PaddleClientToken == "" {
log.Error("PADDLE_CLIENT_TOKEN is empty")
os.Exit(1)
}
}
env.MinStorageGB = 20
env.MaxStorageGB = 10_000
env.TrialDuration = 24 * time.Hour
env.TrialStorageGB = 20
env.GracePeriod = 30 * 24 * time.Hour
log.Info("Environment variables loaded successfully!")
}

View File

@@ -171,26 +171,6 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
backup.BackupSizeMb = completedMBs
backup.BackupDurationMs = time.Since(start).Milliseconds()
// Check size limit (0 = unlimited)
if backupConfig.MaxBackupSizeMB > 0 &&
completedMBs > float64(backupConfig.MaxBackupSizeMB) {
errMsg := fmt.Sprintf(
"backup size (%.2f MB) exceeded maximum allowed size (%d MB)",
completedMBs,
backupConfig.MaxBackupSizeMB,
)
backup.Status = backups_core.BackupStatusFailed
backup.IsSkipRetry = true
backup.FailMessage = &errMsg
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to save backup with size exceeded error", "error", err)
}
cancel() // Cancel the backup context
return
}
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to update backup progress", "error", err)
}

View File

@@ -153,121 +153,3 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
assert.Equal(t, notifier.ID, capturedNotifier.ID)
})
}
func Test_BackupSizeLimits(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
t.Run("UnlimitedSize_MaxBackupSizeMBIsZero_BackupCompletes", func(t *testing.T) {
// Enable backups with unlimited size (0)
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 0 // unlimited
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateLargeBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup completed successfully even with large size
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
assert.Equal(t, float64(10000), updatedBackup.BackupSizeMb)
assert.Nil(t, updatedBackup.FailMessage)
})
t.Run("SizeExceeded_BackupFailedWithIsSkipRetry", func(t *testing.T) {
// Enable backups with 5 MB limit
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 5
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateProgressiveBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup was marked as failed with IsSkipRetry=true
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusFailed, updatedBackup.Status)
assert.True(t, updatedBackup.IsSkipRetry)
assert.NotNil(t, updatedBackup.FailMessage)
assert.Contains(t, *updatedBackup.FailMessage, "exceeded maximum allowed size")
assert.Contains(t, *updatedBackup.FailMessage, "10.00 MB")
assert.Contains(t, *updatedBackup.FailMessage, "5 MB")
assert.Greater(t, updatedBackup.BackupSizeMb, float64(5))
})
t.Run("SizeWithinLimit_BackupCompletes", func(t *testing.T) {
// Enable backups with 100 MB limit
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 100
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateMediumBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup completed successfully
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
assert.Equal(t, float64(50), updatedBackup.BackupSizeMb)
assert.Nil(t, updatedBackup.FailMessage)
})
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/google/uuid"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
@@ -26,6 +27,7 @@ type BackupCleaner struct {
backupRepository *backups_core.BackupRepository
storageService *storages.StorageService
backupConfigService *backups_config.BackupConfigService
billingService BillingService
fieldEncryptor util_encryption.FieldEncryptor
logger *slog.Logger
backupRemoveListeners []backups_core.BackupRemoveListener
@@ -44,6 +46,10 @@ func (c *BackupCleaner) Run(ctx context.Context) {
return
}
retentionLog := c.logger.With("task_name", "clean_by_retention_policy")
exceededLog := c.logger.With("task_name", "clean_exceeded_storage_backups")
staleLog := c.logger.With("task_name", "clean_stale_basebackups")
ticker := time.NewTicker(cleanerTickerInterval)
defer ticker.Stop()
@@ -52,16 +58,16 @@ func (c *BackupCleaner) Run(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
if err := c.cleanByRetentionPolicy(); err != nil {
c.logger.Error("Failed to clean backups by retention policy", "error", err)
if err := c.cleanByRetentionPolicy(retentionLog); err != nil {
retentionLog.Error("failed to clean backups by retention policy", "error", err)
}
if err := c.cleanExceededBackups(); err != nil {
c.logger.Error("Failed to clean exceeded backups", "error", err)
if err := c.cleanExceededStorageBackups(exceededLog); err != nil {
exceededLog.Error("failed to clean exceeded backups", "error", err)
}
if err := c.cleanStaleUploadedBasebackups(); err != nil {
c.logger.Error("Failed to clean stale uploaded basebackups", "error", err)
if err := c.cleanStaleUploadedBasebackups(staleLog); err != nil {
staleLog.Error("failed to clean stale uploaded basebackups", "error", err)
}
}
}
@@ -104,7 +110,7 @@ func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemo
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
}
func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
func (c *BackupCleaner) cleanStaleUploadedBasebackups(logger *slog.Logger) error {
staleBackups, err := c.backupRepository.FindStaleUploadedBasebackups(
time.Now().UTC().Add(-10 * time.Minute),
)
@@ -113,31 +119,30 @@ func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
}
for _, backup := range staleBackups {
backupLog := logger.With("database_id", backup.DatabaseID, "backup_id", backup.ID)
staleStorage, storageErr := c.storageService.GetStorageByID(backup.StorageID)
if storageErr != nil {
c.logger.Error(
"Failed to get storage for stale basebackup cleanup",
"backupId", backup.ID,
"storageId", backup.StorageID,
backupLog.Error(
"failed to get storage for stale basebackup cleanup",
"storage_id", backup.StorageID,
"error", storageErr,
)
} else {
if err := staleStorage.DeleteFile(c.fieldEncryptor, backup.FileName); err != nil {
c.logger.Error(
"Failed to delete stale basebackup file",
"backupId", backup.ID,
"fileName", backup.FileName,
"error", err,
backupLog.Error(
fmt.Sprintf("failed to delete stale basebackup file: %s", backup.FileName),
"error",
err,
)
}
metadataFileName := backup.FileName + ".metadata"
if err := staleStorage.DeleteFile(c.fieldEncryptor, metadataFileName); err != nil {
c.logger.Error(
"Failed to delete stale basebackup metadata file",
"backupId", backup.ID,
"fileName", metadataFileName,
"error", err,
backupLog.Error(
fmt.Sprintf("failed to delete stale basebackup metadata file: %s", metadataFileName),
"error",
err,
)
}
}
@@ -147,77 +152,67 @@ func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
backup.FailMessage = &failMsg
if err := c.backupRepository.Save(backup); err != nil {
c.logger.Error(
"Failed to mark stale uploaded basebackup as failed",
"backupId", backup.ID,
"error", err,
)
backupLog.Error("failed to mark stale uploaded basebackup as failed", "error", err)
continue
}
c.logger.Info(
"Marked stale uploaded basebackup as failed and cleaned storage",
"backupId", backup.ID,
"databaseId", backup.DatabaseID,
)
backupLog.Info("marked stale uploaded basebackup as failed and cleaned storage")
}
return nil
}
func (c *BackupCleaner) cleanByRetentionPolicy() error {
func (c *BackupCleaner) cleanByRetentionPolicy(logger *slog.Logger) error {
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
dbLog := logger.With("database_id", backupConfig.DatabaseID, "policy", backupConfig.RetentionPolicyType)
var cleanErr error
switch backupConfig.RetentionPolicyType {
case backups_config.RetentionPolicyTypeCount:
cleanErr = c.cleanByCount(backupConfig)
cleanErr = c.cleanByCount(dbLog, backupConfig)
case backups_config.RetentionPolicyTypeGFS:
cleanErr = c.cleanByGFS(backupConfig)
cleanErr = c.cleanByGFS(dbLog, backupConfig)
default:
cleanErr = c.cleanByTimePeriod(backupConfig)
cleanErr = c.cleanByTimePeriod(dbLog, backupConfig)
}
if cleanErr != nil {
c.logger.Error(
"Failed to clean backups by retention policy",
"databaseId", backupConfig.DatabaseID,
"policy", backupConfig.RetentionPolicyType,
"error", cleanErr,
)
dbLog.Error("failed to clean backups by retention policy", "error", cleanErr)
}
}
return nil
}
func (c *BackupCleaner) cleanExceededBackups() error {
func (c *BackupCleaner) cleanExceededStorageBackups(logger *slog.Logger) error {
if !config.GetEnv().IsCloud {
return nil
}
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
if backupConfig.MaxBackupsTotalSizeMB <= 0 {
dbLog := logger.With("database_id", backupConfig.DatabaseID)
subscription, subErr := c.billingService.GetSubscription(dbLog, backupConfig.DatabaseID)
if subErr != nil {
dbLog.Error("failed to get subscription for exceeded backups check", "error", subErr)
continue
}
if err := c.cleanExceededBackupsForDatabase(
backupConfig.DatabaseID,
backupConfig.MaxBackupsTotalSizeMB,
); err != nil {
c.logger.Error(
"Failed to clean exceeded backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
storageLimitMB := int64(subscription.GetBackupsStorageGB()) * 1024
if err := c.cleanExceededBackupsForDatabase(dbLog, backupConfig.DatabaseID, storageLimitMB); err != nil {
dbLog.Error("failed to clean exceeded backups for database", "error", err)
continue
}
}
@@ -225,7 +220,7 @@ func (c *BackupCleaner) cleanExceededBackups() error {
return nil
}
func (c *BackupCleaner) cleanByTimePeriod(backupConfig *backups_config.BackupConfig) error {
func (c *BackupCleaner) cleanByTimePeriod(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
if backupConfig.RetentionTimePeriod == "" {
return nil
}
@@ -255,21 +250,17 @@ func (c *BackupCleaner) cleanByTimePeriod(backupConfig *backups_config.BackupCon
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
logger.Error("failed to delete old backup", "backup_id", backup.ID, "error", err)
continue
}
c.logger.Info(
"Deleted old backup",
"backupId", backup.ID,
"databaseId", backupConfig.DatabaseID,
)
logger.Info("deleted old backup", "backup_id", backup.ID)
}
return nil
}
func (c *BackupCleaner) cleanByCount(backupConfig *backups_config.BackupConfig) error {
func (c *BackupCleaner) cleanByCount(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
if backupConfig.RetentionCount <= 0 {
return nil
}
@@ -298,28 +289,20 @@ func (c *BackupCleaner) cleanByCount(backupConfig *backups_config.BackupConfig)
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete backup by count policy",
"backupId",
backup.ID,
"error",
err,
)
logger.Error("failed to delete backup by count policy", "backup_id", backup.ID, "error", err)
continue
}
c.logger.Info(
"Deleted backup by count policy",
"backupId", backup.ID,
"databaseId", backupConfig.DatabaseID,
"retentionCount", backupConfig.RetentionCount,
logger.Info(
fmt.Sprintf("deleted backup by count policy: retention count is %d", backupConfig.RetentionCount),
"backup_id", backup.ID,
)
}
return nil
}
func (c *BackupCleaner) cleanByGFS(backupConfig *backups_config.BackupConfig) error {
func (c *BackupCleaner) cleanByGFS(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
if backupConfig.RetentionGfsHours <= 0 && backupConfig.RetentionGfsDays <= 0 &&
backupConfig.RetentionGfsWeeks <= 0 && backupConfig.RetentionGfsMonths <= 0 &&
backupConfig.RetentionGfsYears <= 0 {
@@ -357,29 +340,20 @@ func (c *BackupCleaner) cleanByGFS(backupConfig *backups_config.BackupConfig) er
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete backup by GFS policy",
"backupId",
backup.ID,
"error",
err,
)
logger.Error("failed to delete backup by GFS policy", "backup_id", backup.ID, "error", err)
continue
}
c.logger.Info(
"Deleted backup by GFS policy",
"backupId", backup.ID,
"databaseId", backupConfig.DatabaseID,
)
logger.Info("deleted backup by GFS policy", "backup_id", backup.ID)
}
return nil
}
func (c *BackupCleaner) cleanExceededBackupsForDatabase(
logger *slog.Logger,
databaseID uuid.UUID,
limitperDbMB int64,
limitPerDbMB int64,
) error {
for {
backupsTotalSizeMB, err := c.backupRepository.GetTotalSizeByDatabase(databaseID)
@@ -387,7 +361,7 @@ func (c *BackupCleaner) cleanExceededBackupsForDatabase(
return err
}
if backupsTotalSizeMB <= float64(limitperDbMB) {
if backupsTotalSizeMB <= float64(limitPerDbMB) {
break
}
@@ -400,59 +374,27 @@ func (c *BackupCleaner) cleanExceededBackupsForDatabase(
}
if len(oldestBackups) == 0 {
c.logger.Warn(
"No backups to delete but still over limit",
"databaseId",
databaseID,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
)
logger.Warn(fmt.Sprintf(
"no backups to delete but still over limit: total size is %.1f MB, limit is %d MB",
backupsTotalSizeMB, limitPerDbMB,
))
break
}
backup := oldestBackups[0]
if isRecentBackup(backup) {
c.logger.Warn(
"Oldest backup is too recent to delete, stopping size cleanup",
"databaseId",
databaseID,
"backupId",
backup.ID,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
)
break
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete exceeded backup",
"backupId",
backup.ID,
"databaseId",
databaseID,
"error",
err,
)
logger.Error("failed to delete exceeded backup", "backup_id", backup.ID, "error", err)
return err
}
c.logger.Info(
"Deleted exceeded backup",
"backupId",
backup.ID,
"databaseId",
databaseID,
"backupSizeMB",
backup.BackupSizeMb,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
logger.Info(
fmt.Sprintf("deleted exceeded backup: backup size is %.1f MB, total size is %.1f MB, limit is %d MB",
backup.BackupSizeMb, backupsTotalSizeMB, limitPerDbMB),
"backup_id", backup.ID,
)
}

View File

@@ -425,7 +425,7 @@ func Test_CleanByGFS_KeepsCorrectBackupsPerSlot(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -502,7 +502,7 @@ func Test_CleanByGFS_WithWeeklyAndMonthlySlots_KeepsWiderSpread(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -576,7 +576,7 @@ func Test_CleanByGFS_WithHourlySlots_KeepsCorrectBackups(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -677,7 +677,7 @@ func Test_CleanByGFS_SkipsRecentBackup_WhenNotInKeepSet(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -759,7 +759,7 @@ func Test_CleanByGFS_With20DailyBackups_KeepsOnlyExpectedCount(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -844,7 +844,7 @@ func Test_CleanByGFS_WithMultipleBackupsPerDay_KeepsOnlyOnePerDailySlot(t *testi
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -929,7 +929,7 @@ func Test_CleanByGFS_With24HourlySlotsAnd23DailyBackups_DeletesExcessBackups(t *
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -999,7 +999,7 @@ func Test_CleanByGFS_WithDisabledHourlySlotsAnd23DailyBackups_DeletesExcessBacku
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -1069,7 +1069,7 @@ func Test_CleanByGFS_WithDailySlotsAndWeeklyBackups_DeletesExcessBackups(t *test
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -1152,7 +1152,7 @@ func Test_CleanByGFS_WithWeeklySlotsAndMonthlyBackups_DeletesExcessBackups(t *te
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)

View File

@@ -1,14 +1,17 @@
package backuping
import (
"log/slog"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
billing_models "databasus-backend/internal/features/billing/models"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
@@ -17,6 +20,7 @@ import (
users_testing "databasus-backend/internal/features/users/testing"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/storage"
"databasus-backend/internal/util/logger"
"databasus-backend/internal/util/period"
)
@@ -51,6 +55,7 @@ func Test_CleanOldBackups_DeletesBackupsOlderThanRetentionTimePeriod(t *testing.
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -89,7 +94,7 @@ func Test_CleanOldBackups_DeletesBackupsOlderThanRetentionTimePeriod(t *testing.
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -129,6 +134,7 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -145,7 +151,7 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -154,7 +160,8 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
assert.Equal(t, oldBackup.ID, remainingBackups[0].ID)
}
func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
func Test_CleanExceededBackups_WhenUnderStorageLimit_NoBackupsDeleted(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -178,14 +185,14 @@ func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 100,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -196,15 +203,18 @@ func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 16.67,
BackupSizeMb: 100,
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
}
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -212,7 +222,8 @@ func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
assert.Equal(t, 3, len(remainingBackups))
}
func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T) {
func Test_CleanExceededBackups_WhenOverStorageLimit_DeletesOldestBackups(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -236,18 +247,20 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 30,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// 5 backups at 300 MB each = 1500 MB total, limit = 1 GB (1024 MB)
// Expect 2 oldest deleted, 3 remain (900 MB < 1024 MB)
now := time.Now().UTC()
var backupIDs []uuid.UUID
for i := 0; i < 5; i++ {
@@ -256,7 +269,7 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
BackupSizeMb: 300,
CreatedAt: now.Add(-time.Duration(4-i) * time.Hour),
}
err = backupRepository.Save(backup)
@@ -264,8 +277,11 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
backupIDs = append(backupIDs, backup.ID)
}
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -284,6 +300,7 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
}
func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -307,20 +324,21 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 50,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
now := time.Now().UTC()
// 3 completed at 500 MB each = 1500 MB, limit = 1 GB (1024 MB)
completedBackups := make([]*backups_core.Backup, 3)
for i := 0; i < 3; i++ {
backup := &backups_core.Backup{
@@ -328,7 +346,7 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 30,
BackupSizeMb: 500,
CreatedAt: now.Add(-time.Duration(3-i) * time.Hour),
}
err = backupRepository.Save(backup)
@@ -347,8 +365,11 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
err = backupRepository.Save(inProgressBackup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -365,7 +386,8 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
assert.True(t, inProgressFound, "In-progress backup should not be deleted")
}
func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
func Test_CleanExceededBackups_WithZeroStorageLimit_RemovesAllBackups(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -389,14 +411,14 @@ func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 0,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -408,19 +430,23 @@ func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 100,
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
CreatedAt: time.Now().UTC().Add(-time.Duration(i+2) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
}
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
// StorageGB=0 means no storage allowed — all backups should be removed
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 0, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 10, len(remainingBackups))
assert.Equal(t, 0, len(remainingBackups))
}
func Test_GetTotalSizeByDatabase_CalculatesCorrectly(t *testing.T) {
@@ -522,6 +548,7 @@ func Test_CleanByCount_KeepsNewestNBackups_DeletesOlder(t *testing.T) {
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -545,7 +572,7 @@ func Test_CleanByCount_KeepsNewestNBackups_DeletesOlder(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -594,6 +621,7 @@ func Test_CleanByCount_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -612,7 +640,7 @@ func Test_CleanByCount_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -651,6 +679,7 @@ func Test_CleanByCount_DoesNotDeleteInProgressBackups(t *testing.T) {
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -682,7 +711,7 @@ func Test_CleanByCount_DoesNotDeleteInProgressBackups(t *testing.T) {
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -776,6 +805,7 @@ func Test_CleanByTimePeriod_SkipsRecentBackup_EvenIfOlderThanRetention(t *testin
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -805,7 +835,7 @@ func Test_CleanByTimePeriod_SkipsRecentBackup_EvenIfOlderThanRetention(t *testin
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -847,6 +877,7 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -893,7 +924,7 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -914,7 +945,8 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
assert.True(t, remainingIDs[newestBackup.ID], "Newest backup should be preserved")
}
func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testing.T) {
func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverStorageLimit(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -937,18 +969,18 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
interval := createTestInterval()
// Total size limit is 10 MB. We have two backups of 8 MB each (16 MB total).
// Total size limit = 1 GB (1024 MB). Two backups of 600 MB each (1200 MB total).
// The oldest backup was created 30 minutes ago — within the grace period.
// The cleaner must stop and leave both backups intact.
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 10,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -960,7 +992,7 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 8,
BackupSizeMb: 600,
CreatedAt: now.Add(-30 * time.Minute),
}
newerRecentBackup := &backups_core.Backup{
@@ -968,7 +1000,7 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 8,
BackupSizeMb: 600,
CreatedAt: now.Add(-10 * time.Minute),
}
@@ -977,8 +1009,11 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
err = backupRepository.Save(newerRecentBackup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -991,6 +1026,82 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
)
}
// Verifies the storage-quota cleaner is a no-op when the instance is NOT in
// cloud mode: even with backups well over the mock billing limit, nothing is
// deleted. Note: enableCloud(t) is deliberately not called here.
func Test_CleanExceededStorageBackups_WhenNonCloud_SkipsCleanup(t *testing.T) {
	router := CreateTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	// Tear down fixtures in reverse dependency order.
	defer func() {
		backups, _ := backupRepository.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepository.DeleteByID(backup.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		notifiers.RemoveTestNotifier(notifier)
		storages.RemoveTestStorage(storage.ID)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	interval := createTestInterval()
	backupConfig := &backups_config.BackupConfig{
		DatabaseID:          database.ID,
		IsBackupsEnabled:    true,
		RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
		RetentionTimePeriod: period.PeriodForever,
		StorageID:           &storage.ID,
		BackupIntervalID:    interval.ID,
		BackupInterval:      interval,
		Encryption:          backups_config.BackupEncryptionEncrypted,
	}
	_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
	assert.NoError(t, err)
	// 5 backups at 500 MB each = 2500 MB, would exceed 1 GB limit in cloud mode
	now := time.Now().UTC()
	for i := 0; i < 5; i++ {
		backup := &backups_core.Backup{
			ID:           uuid.New(),
			DatabaseID:   database.ID,
			StorageID:    storage.ID,
			Status:       backups_core.BackupStatusCompleted,
			BackupSizeMb: 500,
			// i+2 hours back — presumably keeps all backups outside any
			// recent-backup grace period; confirm against cleaner logic.
			CreatedAt: now.Add(-time.Duration(i+2) * time.Hour),
		}
		err = backupRepository.Save(backup)
		assert.NoError(t, err)
	}
	// IsCloud is false by default — cleaner should skip entirely
	mockBilling := &mockBillingService{
		subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
	}
	cleaner := CreateTestBackupCleaner(mockBilling)
	err = cleaner.cleanExceededStorageBackups(testLogger())
	assert.NoError(t, err)
	remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
	assert.NoError(t, err)
	assert.Equal(t, 5, len(remainingBackups), "All backups must remain in non-cloud mode")
}
// mockBillingService is a test double for the BillingService interface,
// returning a canned subscription and/or error.
type mockBillingService struct {
	subscription *billing_models.Subscription // returned by GetSubscription
	err          error                        // returned by GetSubscription
}

// GetSubscription satisfies BillingService; the logger and databaseID
// arguments are ignored by the mock.
func (m *mockBillingService) GetSubscription(
	logger *slog.Logger,
	databaseID uuid.UUID,
) (*billing_models.Subscription, error) {
	return m.subscription, m.err
}
// Mock listener for testing
type mockBackupRemoveListener struct {
onBeforeBackupRemove func(*backups_core.Backup) error
@@ -1041,7 +1152,7 @@ func Test_CleanStaleUploadedBasebackups_MarksAsFailed(t *testing.T) {
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups()
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
assert.NoError(t, err)
updated, err := backupRepository.FindByID(staleBackup.ID)
@@ -1088,7 +1199,7 @@ func Test_CleanStaleUploadedBasebackups_SkipsRecentUploads(t *testing.T) {
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups()
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
assert.NoError(t, err)
updated, err := backupRepository.FindByID(recentBackup.ID)
@@ -1131,7 +1242,7 @@ func Test_CleanStaleUploadedBasebackups_SkipsActiveStreaming(t *testing.T) {
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups()
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
assert.NoError(t, err)
updated, err := backupRepository.FindByID(activeBackup.ID)
@@ -1179,7 +1290,7 @@ func Test_CleanStaleUploadedBasebackups_CleansStorageFiles(t *testing.T) {
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups()
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
assert.NoError(t, err)
updated, err := backupRepository.FindByID(staleBackup.ID)
@@ -1189,6 +1300,18 @@ func Test_CleanStaleUploadedBasebackups_CleansStorageFiles(t *testing.T) {
assert.Contains(t, *updated.FailMessage, "finalization timed out")
}
// enableCloud switches the process into cloud mode for the duration of the
// test and restores non-cloud mode on cleanup.
// NOTE(review): this mutates shared global env state, so tests that use it
// must not run with t.Parallel().
func enableCloud(t *testing.T) {
	t.Helper()
	config.GetEnv().IsCloud = true
	t.Cleanup(func() {
		config.GetEnv().IsCloud = false
	})
}
// testLogger returns the application logger tagged with a fixed "test"
// task name, matching the task_name attribute used by production tasks.
func testLogger() *slog.Logger {
	return logger.GetLogger().With("task_name", "test")
}
func createTestInterval() *intervals.Interval {
timeOfDay := "04:00"
interval := &intervals.Interval{

View File

@@ -10,6 +10,7 @@ import (
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/billing"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
@@ -28,6 +29,7 @@ var backupCleaner = &BackupCleaner{
backupRepository,
storages.GetStorageService(),
backups_config.GetBackupConfigService(),
billing.GetBillingService(),
encryption.GetFieldEncryptor(),
logger.GetLogger(),
[]backups_core.BackupRemoveListener{},
@@ -73,6 +75,7 @@ var backupsScheduler = &BackupsScheduler{
taskCancelManager,
backupNodesRegistry,
databases.GetDatabaseService(),
billing.GetBillingService(),
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]BackupToNodeRelation),

View File

@@ -0,0 +1,13 @@
package backuping
import (
"log/slog"
"github.com/google/uuid"
billing_models "databasus-backend/internal/features/billing/models"
)
// BillingService abstracts subscription lookups so the backuping package can
// consult billing limits without a hard dependency on the billing feature,
// and so tests can inject a mock implementation.
type BillingService interface {
	GetSubscription(logger *slog.Logger, databaseID uuid.UUID) (*billing_models.Subscription, error)
}

View File

@@ -29,6 +29,7 @@ type BackupsScheduler struct {
taskCancelManager *task_cancellation.TaskCancelManager
backupNodesRegistry *BackupNodesRegistry
databaseService *databases.DatabaseService
billingService BillingService
lastBackupTime time.Time
logger *slog.Logger
@@ -127,6 +128,34 @@ func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotif
return
}
if config.GetEnv().IsCloud {
subscription, subErr := s.billingService.GetSubscription(s.logger, database.ID)
if subErr != nil || !subscription.CanCreateNewBackups() {
failMessage := "subscription has expired, please renew"
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: *backupConfig.StorageID,
Status: backups_core.BackupStatusFailed,
FailMessage: &failMessage,
IsSkipRetry: true,
CreatedAt: time.Now().UTC(),
}
backup.GenerateFilename(database.Name)
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error(
"failed to save failed backup for expired subscription",
"database_id", database.ID,
"error", err,
)
}
return
}
}
// Check for existing in-progress backups
inProgressBackups, err := s.backupRepository.FindByDatabaseIdAndStatus(
database.ID,
@@ -346,6 +375,27 @@ func (s *BackupsScheduler) runPendingBackups() error {
continue
}
if config.GetEnv().IsCloud {
subscription, subErr := s.billingService.GetSubscription(s.logger, backupConfig.DatabaseID)
if subErr != nil {
s.logger.Warn(
"failed to get subscription, skipping backup",
"database_id", backupConfig.DatabaseID,
"error", subErr,
)
continue
}
if !subscription.CanCreateNewBackups() {
s.logger.Debug(
"subscription is not active, skipping scheduled backup",
"database_id", backupConfig.DatabaseID,
"subscription_status", subscription.Status,
)
continue
}
}
s.StartBackup(database, remainedBackupTryCount == 1)
continue
}

View File

@@ -10,6 +10,7 @@ import (
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
billing_models "databasus-backend/internal/features/billing/models"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
@@ -968,7 +969,7 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
scheduler := CreateTestScheduler()
scheduler := CreateTestScheduler(nil)
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
@@ -1065,7 +1066,7 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
scheduler := CreateTestScheduler()
scheduler := CreateTestScheduler(nil)
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
@@ -1332,7 +1333,7 @@ func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCa
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// Create scheduler
scheduler := CreateTestScheduler()
scheduler := CreateTestScheduler(nil)
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
@@ -1458,3 +1459,313 @@ func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCa
time.Sleep(200 * time.Millisecond)
}
// Verifies that in cloud mode StartBackup refuses to run when the subscription
// is expired: instead of starting a backup it records a single failed,
// non-retryable backup row with an explanatory fail message.
func Test_StartBackup_WhenCloudAndSubscriptionExpired_CreatesFailedBackup(t *testing.T) {
	cache_utils.ClearAllCache()
	mockBilling := &mockBillingService{
		subscription: &billing_models.Subscription{
			Status: billing_models.StatusExpired,
		},
	}
	scheduler := CreateTestScheduler(mockBilling)
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	// Tear down fixtures in reverse dependency order.
	defer func() {
		backups, _ := backupRepository.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepository.DeleteByID(backup.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
	assert.NoError(t, err)
	backupConfig.Storage = storage
	backupConfig.StorageID = &storage.ID
	_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
	assert.NoError(t, err)
	enableCloud(t)
	scheduler.StartBackup(database, false)
	// Short wait to let the failed-backup row be persisted.
	time.Sleep(100 * time.Millisecond)
	backups, err := backupRepository.FindByDatabaseID(database.ID)
	assert.NoError(t, err)
	assert.Len(t, backups, 1)
	newestBackup := backups[0]
	assert.Equal(t, backups_core.BackupStatusFailed, newestBackup.Status)
	assert.NotNil(t, newestBackup.FailMessage)
	assert.Equal(t, "subscription has expired, please renew", *newestBackup.FailMessage)
	assert.True(t, newestBackup.IsSkipRetry)
	time.Sleep(200 * time.Millisecond)
}
// Verifies that in cloud mode StartBackup proceeds normally when the
// subscription is active: the backup runs to completion on a backuper node.
func Test_StartBackup_WhenCloudAndSubscriptionActive_ProceedsNormally(t *testing.T) {
	cache_utils.ClearAllCache()
	// A running backuper node is required for the backup to actually execute.
	backuperNode := CreateTestBackuperNode()
	cancel := StartBackuperNodeForTest(t, backuperNode)
	defer StopBackuperNodeForTest(t, cancel, backuperNode)
	mockBilling := &mockBillingService{
		subscription: &billing_models.Subscription{
			Status:    billing_models.StatusActive,
			StorageGB: 10,
		},
	}
	scheduler := CreateTestScheduler(mockBilling)
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	// Tear down fixtures in reverse dependency order.
	defer func() {
		backups, _ := backupRepository.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepository.DeleteByID(backup.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
	assert.NoError(t, err)
	backupConfig.Storage = storage
	backupConfig.StorageID = &storage.ID
	_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
	assert.NoError(t, err)
	enableCloud(t)
	scheduler.StartBackup(database, false)
	WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
	backups, err := backupRepository.FindByDatabaseID(database.ID)
	assert.NoError(t, err)
	assert.Len(t, backups, 1)
	newestBackup := backups[0]
	assert.Equal(t, backups_core.BackupStatusCompleted, newestBackup.Status)
	time.Sleep(200 * time.Millisecond)
}
// Verifies that the pending-backups sweep in cloud mode silently skips
// databases whose subscription is expired: no new backup row (failed or
// otherwise) is created for the due schedule.
//
// Fix over previous revision: the error returns of backupRepository.Save and
// scheduler.runPendingBackups were ignored; both are now asserted so a setup
// or sweep failure surfaces as a test failure instead of a confusing
// downstream assertion.
func Test_RunPendingBackups_WhenCloudAndSubscriptionExpired_SilentlySkips(t *testing.T) {
	cache_utils.ClearAllCache()
	mockBilling := &mockBillingService{
		subscription: &billing_models.Subscription{
			Status: billing_models.StatusExpired,
		},
	}
	scheduler := CreateTestScheduler(mockBilling)
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	// Tear down fixtures in reverse dependency order.
	defer func() {
		backups, _ := backupRepository.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepository.DeleteByID(backup.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
	assert.NoError(t, err)
	timeOfDay := "04:00"
	backupConfig.BackupInterval = &intervals.Interval{
		Interval:  intervals.IntervalDaily,
		TimeOfDay: &timeOfDay,
	}
	backupConfig.IsBackupsEnabled = true
	backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
	backupConfig.RetentionTimePeriod = period.PeriodWeek
	backupConfig.Storage = storage
	backupConfig.StorageID = &storage.ID
	_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
	assert.NoError(t, err)
	// A day-old completed backup makes the daily schedule due again.
	err = backupRepository.Save(&backups_core.Backup{
		DatabaseID: database.ID,
		StorageID:  storage.ID,
		Status:     backups_core.BackupStatusCompleted,
		CreatedAt:  time.Now().UTC().Add(-24 * time.Hour),
	})
	assert.NoError(t, err)
	enableCloud(t)
	err = scheduler.runPendingBackups()
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)
	backups, err := backupRepository.FindByDatabaseID(database.ID)
	assert.NoError(t, err)
	assert.Len(t, backups, 1, "No new backup should be created, scheduler silently skips expired subscriptions")
	time.Sleep(200 * time.Millisecond)
}
// Verifies that outside cloud mode the billing check does not apply:
// StartBackup completes a backup even though the injected subscription is
// expired. Note: enableCloud(t) is deliberately not called here.
func Test_StartBackup_WhenNotCloudAndSubscriptionExpired_ProceedsNormally(t *testing.T) {
	cache_utils.ClearAllCache()
	// A running backuper node is required for the backup to actually execute.
	backuperNode := CreateTestBackuperNode()
	cancel := StartBackuperNodeForTest(t, backuperNode)
	defer StopBackuperNodeForTest(t, cancel, backuperNode)
	mockBilling := &mockBillingService{
		subscription: &billing_models.Subscription{
			Status: billing_models.StatusExpired,
		},
	}
	scheduler := CreateTestScheduler(mockBilling)
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	// Tear down fixtures in reverse dependency order.
	defer func() {
		backups, _ := backupRepository.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepository.DeleteByID(backup.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
	assert.NoError(t, err)
	backupConfig.Storage = storage
	backupConfig.StorageID = &storage.ID
	_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
	assert.NoError(t, err)
	scheduler.StartBackup(database, false)
	WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
	backups, err := backupRepository.FindByDatabaseID(database.ID)
	assert.NoError(t, err)
	assert.Len(t, backups, 1)
	assert.Equal(t, backups_core.BackupStatusCompleted, backups[0].Status,
		"Billing check should not apply in non-cloud mode")
	time.Sleep(200 * time.Millisecond)
}
// Verifies that outside cloud mode the pending-backups sweep ignores billing:
// a due schedule triggers a new backup even though the injected subscription
// is expired. Note: enableCloud(t) is deliberately not called here.
//
// Fix over previous revision: the error returns of backupRepository.Save and
// scheduler.runPendingBackups were ignored; both are now asserted so a setup
// or sweep failure surfaces immediately.
func Test_RunPendingBackups_WhenNotCloudAndSubscriptionExpired_ProceedsNormally(t *testing.T) {
	cache_utils.ClearAllCache()
	// A running backuper node is required for the new backup to execute.
	backuperNode := CreateTestBackuperNode()
	cancel := StartBackuperNodeForTest(t, backuperNode)
	defer StopBackuperNodeForTest(t, cancel, backuperNode)
	mockBilling := &mockBillingService{
		subscription: &billing_models.Subscription{
			Status: billing_models.StatusExpired,
		},
	}
	scheduler := CreateTestScheduler(mockBilling)
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	// Tear down fixtures in reverse dependency order.
	defer func() {
		backups, _ := backupRepository.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepository.DeleteByID(backup.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
	assert.NoError(t, err)
	timeOfDay := "04:00"
	backupConfig.BackupInterval = &intervals.Interval{
		Interval:  intervals.IntervalDaily,
		TimeOfDay: &timeOfDay,
	}
	backupConfig.IsBackupsEnabled = true
	backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
	backupConfig.RetentionTimePeriod = period.PeriodWeek
	backupConfig.Storage = storage
	backupConfig.StorageID = &storage.ID
	_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
	assert.NoError(t, err)
	// A day-old completed backup makes the daily schedule due again.
	err = backupRepository.Save(&backups_core.Backup{
		DatabaseID: database.ID,
		StorageID:  storage.ID,
		Status:     backups_core.BackupStatusCompleted,
		CreatedAt:  time.Now().UTC().Add(-24 * time.Hour),
	})
	assert.NoError(t, err)
	err = scheduler.runPendingBackups()
	assert.NoError(t, err)
	WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
	backups, err := backupRepository.FindByDatabaseID(database.ID)
	assert.NoError(t, err)
	assert.Len(t, backups, 2, "Billing check should not apply in non-cloud mode, new backup should be created")
	time.Sleep(200 * time.Millisecond)
}

View File

@@ -35,58 +35,74 @@ func CreateTestRouter() *gin.Engine {
return router
}
// CreateTestBackupCleaner builds a BackupCleaner wired to the shared test
// repository and real services, with the given billing service injected
// (pass a mock to control subscription limits).
// NOTE(review): positional struct literal — the argument order must match the
// BackupCleaner field declaration order exactly; a field reorder in the struct
// silently breaks this wiring.
func CreateTestBackupCleaner(billingService BillingService) *BackupCleaner {
	return &BackupCleaner{
		backupRepository,
		storages.GetStorageService(),
		backups_config.GetBackupConfigService(),
		billingService,
		encryption.GetFieldEncryptor(),
		logger.GetLogger(),
		[]backups_core.BackupRemoveListener{},
		sync.Once{},
		atomic.Bool{},
	}
}
func CreateTestBackuperNode() *BackuperNode {
return &BackuperNode{
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
nodeID: uuid.New(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
taskCancelManager,
backupNodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
uuid.New(),
time.Time{},
sync.Once{},
atomic.Bool{},
}
}
func CreateTestBackuperNodeWithUseCase(useCase backups_core.CreateBackupUsecase) *BackuperNode {
return &BackuperNode{
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: useCase,
nodeID: uuid.New(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
taskCancelManager,
backupNodesRegistry,
logger.GetLogger(),
useCase,
uuid.New(),
time.Time{},
sync.Once{},
atomic.Bool{},
}
}
func CreateTestScheduler() *BackupsScheduler {
func CreateTestScheduler(billingService BillingService) *BackupsScheduler {
return &BackupsScheduler{
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
taskCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
lastBackupTime: time.Now().UTC(),
logger: logger.GetLogger(),
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
backuperNode: CreateTestBackuperNode(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
backupRepository,
backups_config.GetBackupConfigService(),
taskCancelManager,
backupNodesRegistry,
databases.GetDatabaseService(),
billingService,
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]BackupToNodeRelation),
CreateTestBackuperNode(),
sync.Once{},
atomic.Bool{},
}
}

View File

@@ -1263,7 +1263,7 @@ func Test_MakeBackup_VerifyBackupAndMetadataFilesExistInStorage(t *testing.T) {
backuperCancel := backuping.StartBackuperNodeForTest(t, backuperNode)
defer backuping.StopBackuperNodeForTest(t, backuperCancel, backuperNode)
scheduler := backuping.CreateTestScheduler()
scheduler := backuping.CreateTestScheduler(nil)
schedulerCancel := backuping.StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
@@ -1838,7 +1838,7 @@ func Test_DeleteBackup_RemovesBackupAndMetadataFilesFromDisk(t *testing.T) {
backuperCancel := backuping.StartBackuperNodeForTest(t, backuperNode)
defer backuping.StopBackuperNodeForTest(t, backuperCancel, backuperNode)
scheduler := backuping.CreateTestScheduler()
scheduler := backuping.CreateTestScheduler(nil)
schedulerCancel := backuping.StartSchedulerForTest(t, scheduler)
defer schedulerCancel()

View File

@@ -16,7 +16,6 @@ type BackupConfigController struct {
func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
router.POST("/backup-configs/save", c.SaveBackupConfig)
router.GET("/backup-configs/database/:id/plan", c.GetDatabasePlan)
router.GET("/backup-configs/database/:id", c.GetBackupConfigByDbID)
router.GET("/backup-configs/storage/:id/is-using", c.IsStorageUsing)
router.GET("/backup-configs/storage/:id/databases-count", c.CountDatabasesForStorage)
@@ -93,39 +92,6 @@ func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
ctx.JSON(http.StatusOK, backupConfig)
}
// GetDatabasePlan
// @Summary Get database plan by database ID
// @Description Get the plan limits for a specific database (max backup size, max total size, max storage period)
// @Tags backup-configs
// @Produce json
// @Param id path string true "Database ID"
// @Success 200 {object} plans.DatabasePlan
// @Failure 400 {object} map[string]string "Invalid database ID"
// @Failure 401 {object} map[string]string "User not authenticated"
// @Failure 404 {object} map[string]string "Database not found or access denied"
// @Router /backup-configs/database/{id}/plan [get]
func (c *BackupConfigController) GetDatabasePlan(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
return
}
plan, err := c.backupConfigService.GetDatabasePlan(user, id)
if err != nil {
ctx.JSON(http.StatusNotFound, gin.H{"error": "database plan not found"})
return
}
ctx.JSON(http.StatusOK, plan)
}
// IsStorageUsing
// @Summary Check if storage is being used
// @Description Check if a storage is currently being used by any backup configuration

View File

@@ -17,14 +17,12 @@ import (
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
local_storage "databasus-backend/internal/features/storages/models/local"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/storage"
"databasus-backend/internal/util/period"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
@@ -326,218 +324,13 @@ func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T)
&response,
)
var plan plans.DatabasePlan
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
"Bearer "+owner.Token,
http.StatusOK,
&plan,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.False(t, response.IsBackupsEnabled)
assert.Equal(t, plan.MaxStoragePeriod, response.RetentionTimePeriod)
assert.Equal(t, plan.MaxBackupSizeMB, response.MaxBackupSizeMB)
assert.Equal(t, plan.MaxBackupsTotalSizeMB, response.MaxBackupsTotalSizeMB)
assert.True(t, response.IsRetryIfFailed)
assert.Equal(t, 3, response.MaxFailedTriesCount)
assert.NotNil(t, response.BackupInterval)
}
func Test_GetDatabasePlan_ForNewDatabase_PlanAlwaysReturned(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
var response plans.DatabasePlan
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
"Bearer "+owner.Token,
http.StatusOK,
&response,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.NotNil(t, response.MaxBackupSizeMB)
assert.NotNil(t, response.MaxBackupsTotalSizeMB)
assert.NotEmpty(t, response.MaxStoragePeriod)
}
func Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Get plan via API (triggers auto-creation)
var plan plans.DatabasePlan
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
"Bearer "+owner.Token,
http.StatusOK,
&plan,
)
assert.Equal(t, database.ID, plan.DatabaseID)
// Adjust plan limits directly in database to fixed restrictive values
err := storage.GetDb().Model(&plans.DatabasePlan{}).
Where("database_id = ?", database.ID).
Updates(map[string]any{
"max_backup_size_mb": 100,
"max_backups_total_size_mb": 1000,
"max_storage_period": period.PeriodMonth,
}).Error
assert.NoError(t, err)
// Test 1: Try to save backup config with exceeded backup size limit
timeOfDay := "04:00"
backupConfigExceededSize := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 200, // Exceeds limit of 100
MaxBackupsTotalSizeMB: 800,
}
respExceededSize := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigExceededSize,
http.StatusBadRequest,
)
assert.Contains(t, string(respExceededSize.Body), "max backup size exceeds plan limit")
// Test 2: Try to save backup config with exceeded total size limit
backupConfigExceededTotal := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 50,
MaxBackupsTotalSizeMB: 2000, // Exceeds limit of 1000
}
respExceededTotal := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigExceededTotal,
http.StatusBadRequest,
)
assert.Contains(t, string(respExceededTotal.Body), "max total backups size exceeds plan limit")
// Test 3: Try to save backup config with exceeded storage period limit
backupConfigExceededPeriod := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodYear, // Exceeds limit of Month
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 80,
MaxBackupsTotalSizeMB: 800,
}
respExceededPeriod := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigExceededPeriod,
http.StatusBadRequest,
)
assert.Contains(t, string(respExceededPeriod.Body), "storage period exceeds plan limit")
// Test 4: Save backup config within all limits - should succeed
backupConfigValid := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek, // Within Month limit
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 80, // Within 100 limit
MaxBackupsTotalSizeMB: 800, // Within 1000 limit
}
var responseValid BackupConfig
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigValid,
http.StatusOK,
&responseValid,
)
assert.Equal(t, database.ID, responseValid.DatabaseID)
assert.Equal(t, int64(80), responseValid.MaxBackupSizeMB)
assert.Equal(t, int64(800), responseValid.MaxBackupsTotalSizeMB)
assert.Equal(t, period.PeriodWeek, responseValid.RetentionTimePeriod)
}
func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string

View File

@@ -6,7 +6,6 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/logger"
@@ -20,7 +19,6 @@ var (
storages.GetStorageService(),
notifiers.GetNotifierService(),
workspaces_services.GetWorkspaceService(),
plans.GetDatabasePlanService(),
nil,
}
)

View File

@@ -9,7 +9,6 @@ import (
"databasus-backend/internal/config"
"databasus-backend/internal/features/intervals"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/util/period"
)
@@ -42,11 +41,6 @@ type BackupConfig struct {
MaxFailedTriesCount int `json:"maxFailedTriesCount" gorm:"column:max_failed_tries_count;type:int;not null"`
Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
// MaxBackupSizeMB limits individual backup size. 0 = unlimited.
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
// MaxBackupsTotalSizeMB limits total size of all backups. 0 = unlimited.
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
}
func (h *BackupConfig) TableName() string {
@@ -86,12 +80,12 @@ func (b *BackupConfig) AfterFind(tx *gorm.DB) error {
return nil
}
func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
func (b *BackupConfig) Validate() error {
if b.BackupIntervalID == uuid.Nil && b.BackupInterval == nil {
return errors.New("backup interval is required")
}
if err := b.validateRetentionPolicy(plan); err != nil {
if err := b.validateRetentionPolicy(); err != nil {
return err
}
@@ -110,67 +104,38 @@ func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
}
}
if b.MaxBackupSizeMB < 0 {
return errors.New("max backup size must be non-negative")
}
if b.MaxBackupsTotalSizeMB < 0 {
return errors.New("max backups total size must be non-negative")
}
if plan.MaxBackupSizeMB > 0 {
if b.MaxBackupSizeMB == 0 || b.MaxBackupSizeMB > plan.MaxBackupSizeMB {
return errors.New("max backup size exceeds plan limit")
}
}
if plan.MaxBackupsTotalSizeMB > 0 {
if b.MaxBackupsTotalSizeMB == 0 ||
b.MaxBackupsTotalSizeMB > plan.MaxBackupsTotalSizeMB {
return errors.New("max total backups size exceeds plan limit")
}
}
return nil
}
func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
return &BackupConfig{
DatabaseID: newDatabaseID,
IsBackupsEnabled: b.IsBackupsEnabled,
RetentionPolicyType: b.RetentionPolicyType,
RetentionTimePeriod: b.RetentionTimePeriod,
RetentionCount: b.RetentionCount,
RetentionGfsHours: b.RetentionGfsHours,
RetentionGfsDays: b.RetentionGfsDays,
RetentionGfsWeeks: b.RetentionGfsWeeks,
RetentionGfsMonths: b.RetentionGfsMonths,
RetentionGfsYears: b.RetentionGfsYears,
BackupIntervalID: uuid.Nil,
BackupInterval: b.BackupInterval.Copy(),
StorageID: b.StorageID,
SendNotificationsOn: b.SendNotificationsOn,
IsRetryIfFailed: b.IsRetryIfFailed,
MaxFailedTriesCount: b.MaxFailedTriesCount,
Encryption: b.Encryption,
MaxBackupSizeMB: b.MaxBackupSizeMB,
MaxBackupsTotalSizeMB: b.MaxBackupsTotalSizeMB,
DatabaseID: newDatabaseID,
IsBackupsEnabled: b.IsBackupsEnabled,
RetentionPolicyType: b.RetentionPolicyType,
RetentionTimePeriod: b.RetentionTimePeriod,
RetentionCount: b.RetentionCount,
RetentionGfsHours: b.RetentionGfsHours,
RetentionGfsDays: b.RetentionGfsDays,
RetentionGfsWeeks: b.RetentionGfsWeeks,
RetentionGfsMonths: b.RetentionGfsMonths,
RetentionGfsYears: b.RetentionGfsYears,
BackupIntervalID: uuid.Nil,
BackupInterval: b.BackupInterval.Copy(),
StorageID: b.StorageID,
SendNotificationsOn: b.SendNotificationsOn,
IsRetryIfFailed: b.IsRetryIfFailed,
MaxFailedTriesCount: b.MaxFailedTriesCount,
Encryption: b.Encryption,
}
}
func (b *BackupConfig) validateRetentionPolicy(plan *plans.DatabasePlan) error {
func (b *BackupConfig) validateRetentionPolicy() error {
switch b.RetentionPolicyType {
case RetentionPolicyTypeTimePeriod, "":
if b.RetentionTimePeriod == "" {
return errors.New("retention time period is required")
}
if plan.MaxStoragePeriod != period.PeriodForever {
if b.RetentionTimePeriod.CompareTo(plan.MaxStoragePeriod) > 0 {
return errors.New("storage period exceeds plan limit")
}
}
case RetentionPolicyTypeCount:
if b.RetentionCount <= 0 {
return errors.New("retention count must be greater than 0")

View File

@@ -6,248 +6,34 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/intervals"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/util/period"
)
func Test_Validate_WhenRetentionTimePeriodIsWeekAndPlanAllowsMonth_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodWeek
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenRetentionTimePeriodIsYearAndPlanAllowsMonth_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodYear
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
err := config.Validate(plan)
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenRetentionTimePeriodIsForeverAndPlanAllowsForever_ValidationPasses(
t *testing.T,
) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodForever
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodForever
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenRetentionTimePeriodIsForeverAndPlanAllowsYear_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodForever
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodYear
err := config.Validate(plan)
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenRetentionTimePeriodEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodMonth
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenBackupSize100MBAndPlanAllows500MB_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 100
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 500
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenBackupSize500MBAndPlanAllows100MB_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 500
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 100
err := config.Validate(plan)
assert.EqualError(t, err, "max backup size exceeds plan limit")
}
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 0
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanHas500MBLimit_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 500
err := config.Validate(plan)
assert.EqualError(t, err, "max backup size exceeds plan limit")
}
func Test_Validate_WhenBackupSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 500
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 500
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenTotalSize1GBAndPlanAllows5GB_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 1000
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 5000
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenTotalSize5GBAndPlanAllows1GB_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 5000
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 1000
err := config.Validate(plan)
assert.EqualError(t, err, "max total backups size exceeds plan limit")
}
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 0
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanHas1GBLimit_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 1000
err := config.Validate(plan)
assert.EqualError(t, err, "max total backups size exceeds plan limit")
}
func Test_Validate_WhenTotalSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 5000
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 5000
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenAllLimitsAreUnlimitedInPlan_AnyConfigurationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodForever
config.MaxBackupSizeMB = 0
config.MaxBackupsTotalSizeMB = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenMultipleLimitsExceeded_ValidationFailsWithFirstError(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodYear
config.MaxBackupSizeMB = 500
config.MaxBackupsTotalSizeMB = 5000
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
plan.MaxBackupSizeMB = 100
plan.MaxBackupsTotalSizeMB = 1000
err := config.Validate(plan)
assert.Error(t, err)
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenConfigHasInvalidIntervalButPlanIsValid_ValidationFailsOnInterval(
t *testing.T,
) {
func Test_Validate_WhenIntervalIsMissing_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.BackupIntervalID = uuid.Nil
config.BackupInterval = nil
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "backup interval is required")
}
func Test_Validate_WhenIntervalIsMissing_ValidationFailsRegardlessOfPlan(t *testing.T) {
config := createValidBackupConfig()
config.BackupIntervalID = uuid.Nil
config.BackupInterval = nil
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "backup interval is required")
}
func Test_Validate_WhenRetryEnabledButMaxTriesIsZero_ValidationFailsRegardlessOfPlan(t *testing.T) {
func Test_Validate_WhenRetryEnabledButMaxTriesIsZero_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.IsRetryIfFailed = true
config.MaxFailedTriesCount = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "max failed tries count must be greater than 0")
}
func Test_Validate_WhenEncryptionIsInvalid_ValidationFailsRegardlessOfPlan(t *testing.T) {
func Test_Validate_WhenEncryptionIsInvalid_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.Encryption = "INVALID"
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "encryption must be NONE or ENCRYPTED")
}
@@ -255,125 +41,16 @@ func Test_Validate_WhenRetentionTimePeriodIsEmpty_ValidationFails(t *testing.T)
config := createValidBackupConfig()
config.RetentionTimePeriod = ""
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "retention time period is required")
}
func Test_Validate_WhenMaxBackupSizeIsNegative_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = -100
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "max backup size must be non-negative")
}
func Test_Validate_WhenMaxTotalSizeIsNegative_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = -1000
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "max backups total size must be non-negative")
}
func Test_Validate_WhenPlanLimitsAreAtBoundary_ValidationWorks(t *testing.T) {
tests := []struct {
name string
configPeriod period.TimePeriod
planPeriod period.TimePeriod
configSize int64
planSize int64
configTotal int64
planTotal int64
shouldSucceed bool
}{
{
name: "all values just under limit",
configPeriod: period.PeriodWeek,
planPeriod: period.PeriodMonth,
configSize: 99,
planSize: 100,
configTotal: 999,
planTotal: 1000,
shouldSucceed: true,
},
{
name: "all values equal to limit",
configPeriod: period.PeriodMonth,
planPeriod: period.PeriodMonth,
configSize: 100,
planSize: 100,
configTotal: 1000,
planTotal: 1000,
shouldSucceed: true,
},
{
name: "period just over limit",
configPeriod: period.Period3Month,
planPeriod: period.PeriodMonth,
configSize: 100,
planSize: 100,
configTotal: 1000,
planTotal: 1000,
shouldSucceed: false,
},
{
name: "size just over limit",
configPeriod: period.PeriodMonth,
planPeriod: period.PeriodMonth,
configSize: 101,
planSize: 100,
configTotal: 1000,
planTotal: 1000,
shouldSucceed: false,
},
{
name: "total size just over limit",
configPeriod: period.PeriodMonth,
planPeriod: period.PeriodMonth,
configSize: 100,
planSize: 100,
configTotal: 1001,
planTotal: 1000,
shouldSucceed: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = tt.configPeriod
config.MaxBackupSizeMB = tt.configSize
config.MaxBackupsTotalSizeMB = tt.configTotal
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = tt.planPeriod
plan.MaxBackupSizeMB = tt.planSize
plan.MaxBackupsTotalSizeMB = tt.planTotal
err := config.Validate(plan)
if tt.shouldSucceed {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
})
}
}
func Test_Validate_WhenPolicyTypeIsCount_RequiresPositiveCount(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = RetentionPolicyTypeCount
config.RetentionCount = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "retention count must be greater than 0")
}
@@ -382,9 +59,7 @@ func Test_Validate_WhenPolicyTypeIsCount_WithPositiveCount_ValidationPasses(t *t
config.RetentionPolicyType = RetentionPolicyTypeCount
config.RetentionCount = 10
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.NoError(t, err)
}
@@ -396,9 +71,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_RequiresAtLeastOneField(t *testing.T) {
config.RetentionGfsMonths = 0
config.RetentionGfsYears = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "at least one GFS retention field must be greater than 0")
}
@@ -407,9 +80,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_WithOnlyHours_ValidationPasses(t *testing
config.RetentionPolicyType = RetentionPolicyTypeGFS
config.RetentionGfsHours = 24
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.NoError(t, err)
}
@@ -418,9 +89,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_WithOnlyDays_ValidationPasses(t *testing.
config.RetentionPolicyType = RetentionPolicyTypeGFS
config.RetentionGfsDays = 7
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.NoError(t, err)
}
@@ -433,9 +102,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_WithAllFields_ValidationPasses(t *testing
config.RetentionGfsMonths = 12
config.RetentionGfsYears = 3
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.NoError(t, err)
}
@@ -443,35 +110,59 @@ func Test_Validate_WhenPolicyTypeIsInvalid_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = "INVALID"
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "invalid retention policy type")
}
func Test_Validate_WhenCloudAndEncryptionIsNotEncrypted_ValidationFails(t *testing.T) {
enableCloud(t)
backupConfig := createValidBackupConfig()
backupConfig.Encryption = BackupEncryptionNone
err := backupConfig.Validate()
assert.EqualError(t, err, "encryption is mandatory for cloud storage")
}
func Test_Validate_WhenCloudAndEncryptionIsEncrypted_ValidationPasses(t *testing.T) {
enableCloud(t)
backupConfig := createValidBackupConfig()
backupConfig.Encryption = BackupEncryptionEncrypted
err := backupConfig.Validate()
assert.NoError(t, err)
}
func Test_Validate_WhenNotCloudAndEncryptionIsNotEncrypted_ValidationPasses(t *testing.T) {
backupConfig := createValidBackupConfig()
backupConfig.Encryption = BackupEncryptionNone
err := backupConfig.Validate()
assert.NoError(t, err)
}
func enableCloud(t *testing.T) {
t.Helper()
config.GetEnv().IsCloud = true
t.Cleanup(func() {
config.GetEnv().IsCloud = false
})
}
func createValidBackupConfig() *BackupConfig {
intervalID := uuid.New()
return &BackupConfig{
DatabaseID: uuid.New(),
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodMonth,
BackupIntervalID: intervalID,
BackupInterval: &intervals.Interval{ID: intervalID},
SendNotificationsOn: []BackupNotificationType{},
IsRetryIfFailed: false,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 100,
MaxBackupsTotalSizeMB: 1000,
}
}
func createUnlimitedPlan() *plans.DatabasePlan {
return &plans.DatabasePlan{
DatabaseID: uuid.New(),
MaxBackupSizeMB: 0,
MaxBackupsTotalSizeMB: 0,
MaxStoragePeriod: period.PeriodForever,
return &BackupConfig{
DatabaseID: uuid.New(),
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodMonth,
BackupIntervalID: intervalID,
BackupInterval: &intervals.Interval{ID: intervalID},
SendNotificationsOn: []BackupNotificationType{},
IsRetryIfFailed: false,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
}
}

View File

@@ -8,10 +8,10 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
users_models "databasus-backend/internal/features/users/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/period"
)
type BackupConfigService struct {
@@ -20,7 +20,6 @@ type BackupConfigService struct {
storageService *storages.StorageService
notifierService *notifiers.NotifierService
workspaceService *workspaces_services.WorkspaceService
databasePlanService *plans.DatabasePlanService
dbStorageChangeListener BackupConfigStorageChangeListener
}
@@ -46,12 +45,7 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
user *users_models.User,
backupConfig *BackupConfig,
) (*BackupConfig, error) {
plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
if err != nil {
return nil, err
}
if err := backupConfig.Validate(plan); err != nil {
if err := backupConfig.Validate(); err != nil {
return nil, err
}
@@ -88,12 +82,7 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
func (s *BackupConfigService) SaveBackupConfig(
backupConfig *BackupConfig,
) (*BackupConfig, error) {
plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
if err != nil {
return nil, err
}
if err := backupConfig.Validate(plan); err != nil {
if err := backupConfig.Validate(); err != nil {
return nil, err
}
@@ -131,18 +120,6 @@ func (s *BackupConfigService) GetBackupConfigByDbIdWithAuth(
return s.GetBackupConfigByDbId(databaseID)
}
func (s *BackupConfigService) GetDatabasePlan(
user *users_models.User,
databaseID uuid.UUID,
) (*plans.DatabasePlan, error) {
_, err := s.databaseService.GetDatabase(user, databaseID)
if err != nil {
return nil, err
}
return s.databasePlanService.GetDatabasePlan(databaseID)
}
func (s *BackupConfigService) GetBackupConfigByDbId(
databaseID uuid.UUID,
) (*BackupConfig, error) {
@@ -322,20 +299,13 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace(
func (s *BackupConfigService) initializeDefaultConfig(
databaseID uuid.UUID,
) error {
plan, err := s.databasePlanService.GetDatabasePlan(databaseID)
if err != nil {
return err
}
timeOfDay := "04:00"
_, err = s.backupConfigRepository.Save(&BackupConfig{
DatabaseID: databaseID,
IsBackupsEnabled: false,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: plan.MaxStoragePeriod,
MaxBackupSizeMB: plan.MaxBackupSizeMB,
MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB,
_, err := s.backupConfigRepository.Save(&BackupConfig{
DatabaseID: databaseID,
IsBackupsEnabled: false,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.Period3Month,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,

View File

@@ -0,0 +1,305 @@
package billing
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
users_middleware "databasus-backend/internal/features/users/middleware"
"databasus-backend/internal/util/logger"
)
// BillingController exposes the HTTP endpoints for subscription
// management (create, change storage, portal, events, invoices).
type BillingController struct {
	billingService *BillingService
}
// RegisterRoutes mounts all billing endpoints under /billing on the
// authenticated router group.
func (c *BillingController) RegisterRoutes(router *gin.RouterGroup) {
	group := router.Group("/billing")

	group.POST("/subscription", c.CreateSubscription)
	group.POST("/subscription/change-storage", c.ChangeSubscriptionStorage)
	group.POST("/subscription/portal/:subscription_id", c.GetPortalSession)
	group.GET("/subscription/events/:subscription_id", c.GetSubscriptionEvents)
	group.GET("/subscription/invoices/:subscription_id", c.GetInvoices)
	group.GET("/subscription/:database_id", c.GetSubscription)
}
// CreateSubscription
// @Summary Create a new subscription
// @Description Create a billing subscription for the specified database with the given storage
// @Tags billing
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body CreateSubscriptionRequest true "Subscription creation data"
// @Success 200 {object} CreateSubscriptionResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /billing/subscription [post]
func (c *BillingController) CreateSubscription(ctx *gin.Context) {
	// Authenticated user is injected by the users middleware.
	user, ok := users_middleware.GetUserFromContext(ctx)
	if !ok {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}

	var request CreateSubscriptionRequest
	if err := ctx.ShouldBindJSON(&request); err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
		return
	}

	// Fresh request_id correlates every log line of this request.
	log := logger.GetLogger().With(
		"request_id", uuid.New(),
		"database_id", request.DatabaseID,
		"user_id", user.ID,
	)

	transactionID, err := c.billingService.CreateSubscription(
		log,
		user,
		request.DatabaseID,
		request.StorageGB,
	)
	if err != nil {
		log.Error("Failed to create subscription", "error", err)
		ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create subscription"})
		return
	}

	// The transaction id is used by the frontend to open the provider checkout.
	ctx.JSON(http.StatusOK, CreateSubscriptionResponse{PaddleTransactionID: transactionID})
}
// ChangeSubscriptionStorage
// @Summary Change subscription storage
// @Description Update the storage allocation for an existing subscription
// @Tags billing
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body ChangeStorageRequest true "New storage configuration"
// @Success 200 {object} ChangeStorageResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /billing/subscription/change-storage [post]
func (c *BillingController) ChangeSubscriptionStorage(ctx *gin.Context) {
	// Reject anonymous callers before reading the body.
	user, ok := users_middleware.GetUserFromContext(ctx)
	if !ok {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}

	var body ChangeStorageRequest
	if err := ctx.ShouldBindJSON(&body); err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
		return
	}

	// Correlate all log lines of this request with a fresh request_id.
	log := logger.GetLogger().With(
		"request_id", uuid.New(),
		"database_id", body.DatabaseID,
		"user_id", user.ID,
	)

	outcome, err := c.billingService.ChangeSubscriptionStorage(log, user, body.DatabaseID, body.StorageGB)
	if err != nil {
		log.Error("Failed to change subscription storage", "error", err)
		ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to change subscription storage"})
		return
	}

	ctx.JSON(http.StatusOK, ChangeStorageResponse{
		ApplyMode: outcome.ApplyMode,
		CurrentGB: outcome.CurrentGB,
		PendingGB: outcome.PendingGB,
	})
}
// GetPortalSession
// @Summary Get billing portal session
// @Description Generate a portal session URL for managing the subscription
// @Tags billing
// @Produce json
// @Security BearerAuth
// @Param subscription_id path string true "Subscription ID"
// @Success 200 {object} GetPortalSessionResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /billing/subscription/portal/{subscription_id} [post]
func (c *BillingController) GetPortalSession(ctx *gin.Context) {
	user, ok := users_middleware.GetUserFromContext(ctx)
	if !ok {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}

	// Parse instead of MustParse: a malformed id from the client must yield
	// a 400 response, not panic the handler goroutine.
	subscriptionID, err := uuid.Parse(ctx.Param("subscription_id"))
	if err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid subscription ID"})
		return
	}

	log := logger.GetLogger().With(
		"request_id", uuid.New(),
		"subscription_id", subscriptionID,
		"user_id", user.ID,
	)

	url, err := c.billingService.GetPortalURL(log, user, subscriptionID)
	if err != nil {
		log.Error("Failed to get portal session", "error", err)
		ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get portal session"})
		return
	}

	ctx.JSON(http.StatusOK, GetPortalSessionResponse{PortalURL: url})
}
// GetSubscriptionEvents
// @Summary Get subscription events
// @Description Retrieve the event history for a subscription
// @Tags billing
// @Produce json
// @Security BearerAuth
// @Param subscription_id path string true "Subscription ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Success 200 {object} GetSubscriptionEventsResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /billing/subscription/events/{subscription_id} [get]
func (c *BillingController) GetSubscriptionEvents(ctx *gin.Context) {
	user, ok := users_middleware.GetUserFromContext(ctx)
	if !ok {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}

	subscriptionID, err := uuid.Parse(ctx.Param("subscription_id"))
	if err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid subscription ID"})
		return
	}

	// Pagination arrives as query parameters (limit/offset).
	var pagination PaginatedRequest
	if err := ctx.ShouldBindQuery(&pagination); err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
		return
	}

	log := logger.GetLogger().With(
		"request_id", uuid.New(),
		"subscription_id", subscriptionID,
		"user_id", user.ID,
	)

	events, err := c.billingService.GetSubscriptionEvents(log, user, subscriptionID, pagination.Limit, pagination.Offset)
	if err != nil {
		log.Error("Failed to get subscription events", "error", err)
		ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get subscription events"})
		return
	}

	ctx.JSON(http.StatusOK, events)
}
// GetInvoices
// @Summary Get subscription invoices
// @Description Retrieve all invoices for a subscription
// @Tags billing
// @Produce json
// @Security BearerAuth
// @Param subscription_id path string true "Subscription ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Success 200 {object} GetInvoicesResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /billing/subscription/invoices/{subscription_id} [get]
func (c *BillingController) GetInvoices(ctx *gin.Context) {
	user, ok := users_middleware.GetUserFromContext(ctx)
	if !ok {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}

	subscriptionID, err := uuid.Parse(ctx.Param("subscription_id"))
	if err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid subscription ID"})
		return
	}

	// Pagination arrives as query parameters (limit/offset).
	var pagination PaginatedRequest
	if err := ctx.ShouldBindQuery(&pagination); err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
		return
	}

	log := logger.GetLogger().With(
		"request_id", uuid.New(),
		"subscription_id", subscriptionID,
		"user_id", user.ID,
	)

	invoices, err := c.billingService.GetSubscriptionInvoices(log, user, subscriptionID, pagination.Limit, pagination.Offset)
	if err != nil {
		log.Error("Failed to get invoices", "error", err)
		ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get invoices"})
		return
	}

	ctx.JSON(http.StatusOK, invoices)
}
// GetSubscription
// @Summary Get subscription by database
// @Description Retrieve the subscription associated with a specific database
// @Tags billing
// @Produce json
// @Security BearerAuth
// @Param database_id path string true "Database ID"
// @Success 200 {object} billing_models.Subscription
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /billing/subscription/{database_id} [get]
func (c *BillingController) GetSubscription(ctx *gin.Context) {
	user, ok := users_middleware.GetUserFromContext(ctx)
	if !ok {
		ctx.JSON(401, gin.H{"error": "User not authenticated"})
		return
	}
	databaseID, err := uuid.Parse(ctx.Param("database_id"))
	if err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid database ID"})
		return
	}
	// Correlate all log lines of this request with a fresh request_id.
	log := logger.GetLogger().With(
		"request_id", uuid.New(),
		"database_id", databaseID,
		"user_id", user.ID,
	)
	subscription, err := c.billingService.GetSubscriptionByDatabaseID(log, user, databaseID)
	if err != nil {
		// Missing subscription is an expected outcome, not a server error.
		if errors.Is(err, ErrSubscriptionNotFound) {
			ctx.JSON(http.StatusNotFound, gin.H{"error": "Subscription not found"})
			return
		}
		log.Error("failed to get subscription", "error", err)
		ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get subscription"})
		return
	}
	ctx.JSON(200, subscription)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,49 @@
package billing
import (
"sync"
"sync/atomic"
billing_repositories "databasus-backend/internal/features/billing/repositories"
"databasus-backend/internal/features/databases"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/logger"
)
var (
	// billingService is wired positionally against BillingService's field
	// order; the nil slot is the billing provider, set later via
	// SetBillingProvider to avoid a circular import with the provider package.
	// NOTE(review): positional struct literal is fragile — reordering
	// BillingService's fields silently shifts these values; consider keyed fields.
	billingService = &BillingService{
		&billing_repositories.SubscriptionRepository{},
		&billing_repositories.SubscriptionEventRepository{},
		&billing_repositories.InvoiceRepository{},
		nil, // billing provider will be set later to avoid circular dependency
		workspaces_services.GetWorkspaceService(),
		*databases.GetDatabaseService(),
		sync.Once{},
		atomic.Bool{},
	}
	billingController = &BillingController{billingService}

	// setupOnce/isSetup guard SetupDependencies against repeated wiring.
	setupOnce sync.Once
	isSetup   atomic.Bool
)
// GetBillingService returns the package-level billing service singleton.
func GetBillingService() *BillingService {
	return billingService
}

// GetBillingController returns the package-level billing controller singleton.
func GetBillingController() *BillingController {
	return billingController
}
// SetupDependencies registers the billing service as a database-creation
// listener. Only the first call performs the wiring (guarded by setupOnce);
// later calls are no-ops that log a warning.
func SetupDependencies() {
	// Snapshot the flag before Do so this call can tell whether an earlier
	// call already completed the setup.
	wasAlreadySetup := isSetup.Load()
	setupOnce.Do(func() {
		databases.GetDatabaseService().AddDbCreationListener(billingService)
		isSetup.Store(true)
	})
	if wasAlreadySetup {
		logger.GetLogger().Warn("billing.SetupDependencies called multiple times, ignoring subsequent call")
	}
}

View File

@@ -0,0 +1,67 @@
package billing
import (
"github.com/google/uuid"
billing_models "databasus-backend/internal/features/billing/models"
)
// CreateSubscriptionRequest is the body for POST /billing/subscription.
type CreateSubscriptionRequest struct {
	DatabaseID uuid.UUID `json:"databaseId" validate:"required"`
	StorageGB  int       `json:"storageGb" validate:"required,min=1"`
}

// CreateSubscriptionResponse carries the provider transaction id the
// frontend uses to open the checkout flow.
type CreateSubscriptionResponse struct {
	PaddleTransactionID string `json:"paddleTransactionId"`
}

// ChangeStorageApplyMode says whether a storage change takes effect now
// or at the next billing cycle.
type ChangeStorageApplyMode string

const (
	ChangeStorageApplyImmediate ChangeStorageApplyMode = "immediate"
	ChangeStorageApplyNextCycle ChangeStorageApplyMode = "next_cycle"
)

// ChangeStorageRequest is the body for POST /billing/subscription/change-storage.
type ChangeStorageRequest struct {
	DatabaseID uuid.UUID `json:"databaseId" validate:"required"`
	StorageGB  int       `json:"storageGb" validate:"required,min=1"`
}

// ChangeStorageResponse reports how the change was applied; PendingGB is
// set only for next-cycle downgrades.
type ChangeStorageResponse struct {
	ApplyMode ChangeStorageApplyMode `json:"applyMode"`
	CurrentGB int                    `json:"currentGb"`
	PendingGB *int                   `json:"pendingGb,omitempty"`
}

// PortalResponse wraps a provider portal URL.
// NOTE(review): looks redundant with GetPortalSessionResponse — confirm a caller uses it.
type PortalResponse struct {
	URL string `json:"url"`
}

// ChangeStorageResult is the service-layer counterpart of ChangeStorageResponse.
type ChangeStorageResult struct {
	ApplyMode ChangeStorageApplyMode
	CurrentGB int
	PendingGB *int
}

// GetPortalSessionResponse carries the provider-hosted management URL.
type GetPortalSessionResponse struct {
	PortalURL string `json:"url"`
}

// PaginatedRequest binds limit/offset query parameters.
// NOTE(review): no upper bound on Limit here — confirm the service clamps it.
type PaginatedRequest struct {
	Limit  int `form:"limit" json:"limit"`
	Offset int `form:"offset" json:"offset"`
}

// GetSubscriptionEventsResponse is a paginated page of subscription events.
type GetSubscriptionEventsResponse struct {
	Events []*billing_models.SubscriptionEvent `json:"events"`
	Total  int64                               `json:"total"`
	Limit  int                                 `json:"limit"`
	Offset int                                 `json:"offset"`
}

// GetInvoicesResponse is a paginated page of invoices.
type GetInvoicesResponse struct {
	Invoices []*billing_models.Invoice `json:"invoices"`
	Total    int64                     `json:"total"`
	Limit    int                       `json:"limit"`
	Offset   int                       `json:"offset"`
}

View File

@@ -0,0 +1,15 @@
package billing
import "errors"
// Sentinel errors returned by the billing service; compare with errors.Is.
var (
	ErrInvalidStorage       = errors.New("storage must be between 20 and 10000 GB")
	ErrAlreadySubscribed    = errors.New("database already has an active subscription")
	ErrExceedsUsage         = errors.New("cannot downgrade below current storage usage")
	ErrNoChange             = errors.New("requested storage is the same as current")
	ErrDuplicate            = errors.New("duplicate event already processed")
	ErrProviderUnavailable  = errors.New("payment provider unavailable")
	ErrNoActiveSubscription = errors.New("no active subscription for this database")
	ErrAccessDenied         = errors.New("user does not have access to this database")
	ErrSubscriptionNotFound = errors.New("subscription not found")
)

View File

@@ -0,0 +1,24 @@
package billing_models
import (
"time"
"github.com/google/uuid"
)
// Invoice is one provider-issued charge for a subscription billing period.
// Monetary amounts are stored in cents to avoid floating-point rounding.
type Invoice struct {
	ID                uuid.UUID     `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
	SubscriptionID    uuid.UUID     `json:"subscriptionId" gorm:"column:subscription_id;type:uuid;not null"`
	ProviderInvoiceID string        `json:"providerInvoiceId" gorm:"column:provider_invoice_id;type:text;not null"`
	AmountCents       int64         `json:"amountCents" gorm:"column:amount_cents;type:bigint;not null"`
	StorageGB         int           `json:"storageGb" gorm:"column:storage_gb;type:int;not null"`
	PeriodStart       time.Time     `json:"periodStart" gorm:"column:period_start;type:timestamptz;not null"`
	PeriodEnd         time.Time     `json:"periodEnd" gorm:"column:period_end;type:timestamptz;not null"`
	Status            InvoiceStatus `json:"status" gorm:"column:status;type:text;not null"`
	// PaidAt is nil until the provider reports a successful payment.
	PaidAt    *time.Time `json:"paidAt,omitempty" gorm:"column:paid_at;type:timestamptz"`
	CreatedAt time.Time  `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"`
}

// TableName maps the model to the "invoices" table for GORM.
func (Invoice) TableName() string {
	return "invoices"
}

View File

@@ -0,0 +1,11 @@
package billing_models
// InvoiceStatus is the lifecycle state of an Invoice.
type InvoiceStatus string

const (
	InvoiceStatusPending  InvoiceStatus = "pending"
	InvoiceStatusPaid     InvoiceStatus = "paid"
	InvoiceStatusFailed   InvoiceStatus = "failed"
	InvoiceStatusRefunded InvoiceStatus = "refunded"
	InvoiceStatusDisputed InvoiceStatus = "disputed"
)

View File

@@ -0,0 +1,72 @@
package billing_models
import (
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
)
// Subscription is the billing subscription attached to a single database.
// Provider* fields are nil while the subscription is in trial and not yet
// linked to a payment provider record.
type Subscription struct {
	ID         uuid.UUID          `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
	DatabaseID uuid.UUID          `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
	Status     SubscriptionStatus `json:"status" gorm:"column:status;type:text;not null"`
	StorageGB  int                `json:"storageGb" gorm:"column:storage_gb;type:int;not null"`
	// PendingStorageGB holds a downgrade scheduled for the next billing cycle.
	PendingStorageGB   *int      `json:"pendingStorageGb,omitempty" gorm:"column:pending_storage_gb;type:int"`
	CurrentPeriodStart time.Time `json:"currentPeriodStart" gorm:"column:current_period_start;type:timestamptz;not null"`
	CurrentPeriodEnd   time.Time `json:"currentPeriodEnd" gorm:"column:current_period_end;type:timestamptz;not null"`
	CanceledAt         *time.Time `json:"canceledAt,omitempty" gorm:"column:canceled_at;type:timestamptz"`
	// DataRetentionGracePeriodUntil marks how long backups are kept after cancellation.
	DataRetentionGracePeriodUntil *time.Time `json:"dataRetentionGracePeriodUntil,omitempty" gorm:"column:data_retention_grace_period_until;type:timestamptz"`
	ProviderName                  *string    `json:"providerName,omitempty" gorm:"column:provider_name;type:text"`
	ProviderSubID                 *string    `json:"providerSubId,omitempty" gorm:"column:provider_sub_id;type:text"`
	ProviderCustomerID            *string    `json:"providerCustomerId,omitempty" gorm:"column:provider_customer_id;type:text"`
	CreatedAt                     time.Time  `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"`
	UpdatedAt                     time.Time  `json:"updatedAt" gorm:"column:updated_at;type:timestamptz;not null"`
}

// TableName maps the model to the "subscriptions" table for GORM.
func (Subscription) TableName() string {
	return "subscriptions"
}
// PriceCents returns the subscription price in cents, derived from the
// configured per-GB price and the current storage quantity.
func (s *Subscription) PriceCents() int64 {
	return int64(s.StorageGB) * config.GetEnv().PricePerGBCents
}
// CanCreateNewBackups - whether it is allowed to create new backups
// by scheduler or for user manually. Clarification: in grace period
// user can download, delete and restore backups, but cannot create new ones
func (s *Subscription) CanCreateNewBackups() bool {
	switch s.Status {
	case StatusActive, StatusPastDue:
		return true
	case StatusTrial, StatusCanceled:
		// Trial/canceled subscriptions keep the ability only until the
		// current period ends.
		return time.Now().Before(s.CurrentPeriodEnd)
	case StatusExpired:
		return false
	default:
		// NOTE(review): panics on an unrecognized status — acceptable only if
		// status values are fully controlled by this package; confirm.
		panic("unknown subscription status")
	}
}
// GetBackupsStorageGB returns how much backup storage (in GB) the
// subscription currently grants; expired subscriptions and trials past
// their period end grant none.
func (s *Subscription) GetBackupsStorageGB() int {
	if s.Status == StatusExpired {
		return 0
	}
	if s.Status == StatusTrial {
		if !time.Now().Before(s.CurrentPeriodEnd) {
			return 0
		}
		return s.StorageGB
	}
	if s.Status == StatusActive || s.Status == StatusPastDue || s.Status == StatusCanceled {
		return s.StorageGB
	}
	panic("unknown subscription status")
}

View File

@@ -0,0 +1,25 @@
package billing_models
import (
"time"
"github.com/google/uuid"
)
// SubscriptionEvent is one audit-log entry of a subscription's history.
// Old*/New* pairs are populated only for the transitions they describe.
type SubscriptionEvent struct {
	ID             uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
	SubscriptionID uuid.UUID `json:"subscriptionId" gorm:"column:subscription_id;type:uuid;not null"`
	// ProviderEventID is nil for events generated internally rather than by a webhook.
	ProviderEventID *string               `json:"providerEventId,omitempty" gorm:"column:provider_event_id;type:text"`
	Type            SubscriptionEventType `json:"type" gorm:"column:type;type:text;not null"`
	OldStorageGB    *int                  `json:"oldStorageGb,omitempty" gorm:"column:old_storage_gb;type:int"`
	NewStorageGB    *int                  `json:"newStorageGb,omitempty" gorm:"column:new_storage_gb;type:int"`
	OldStatus       *SubscriptionStatus   `json:"oldStatus,omitempty" gorm:"column:old_status;type:text"`
	NewStatus       *SubscriptionStatus   `json:"newStatus,omitempty" gorm:"column:new_status;type:text"`
	CreatedAt       time.Time             `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"`
}

// TableName maps the model to the "subscription_events" table for GORM.
func (SubscriptionEvent) TableName() string {
	return "subscription_events"
}

View File

@@ -0,0 +1,17 @@
package billing_models
// SubscriptionEventType classifies entries in the subscription audit log.
type SubscriptionEventType string

const (
	EventCreated                SubscriptionEventType = "subscription.created"
	EventUpgraded               SubscriptionEventType = "subscription.upgraded"
	EventDowngraded             SubscriptionEventType = "subscription.downgraded"
	EventNewBillingCycleStarted SubscriptionEventType = "subscription.new_billing_cycle_started"
	EventCanceled               SubscriptionEventType = "subscription.canceled"
	EventReactivated            SubscriptionEventType = "subscription.reactivated"
	EventExpired                SubscriptionEventType = "subscription.expired"
	EventPastDue                SubscriptionEventType = "subscription.past_due"
	EventRecoveredFromPastDue   SubscriptionEventType = "subscription.recovered_from_past_due"
	EventRefund                 SubscriptionEventType = "payment.refund"
	EventDispute                SubscriptionEventType = "payment.dispute"
)

View File

@@ -0,0 +1,11 @@
package billing_models
// SubscriptionStatus is the lifecycle state of a Subscription.
type SubscriptionStatus string

const (
	StatusTrial    SubscriptionStatus = "trial"    // trial period (~24h after DB creation)
	StatusActive   SubscriptionStatus = "active"   // paid, everything works
	StatusPastDue  SubscriptionStatus = "past_due" // payment failed, trying to charge again, but everything still works
	StatusCanceled SubscriptionStatus = "canceled" // subscription canceled by user or after past_due (grace period is active)
	StatusExpired  SubscriptionStatus = "expired"  // grace period ended, data marked for deletion, can come from canceled and trial
)

View File

@@ -0,0 +1,22 @@
package billing_models
import (
"time"
"github.com/google/uuid"
)
// WebhookEvent is the provider-agnostic normalization of an incoming
// payment-provider webhook, produced by provider-specific services.
type WebhookEvent struct {
	RequestID       uuid.UUID
	ProviderEventID string
	// DatabaseID is a pointer — presumably only set for events that carry
	// our custom data; confirm against the provider normalizers.
	DatabaseID             *uuid.UUID
	Type                   WebhookEventType
	ProviderSubscriptionID string
	ProviderCustomerID     string
	ProviderInvoiceID      string
	QuantityGB             int
	Status                 SubscriptionStatus
	PeriodStart            *time.Time
	PeriodEnd              *time.Time
	AmountCents            int64
}

View File

@@ -0,0 +1,13 @@
package billing_models
// WebhookEventType classifies normalized webhook events; provider-specific
// event names are mapped onto these values by each provider integration.
type WebhookEventType string

const (
	WHEventSubscriptionCreated        WebhookEventType = "subscription.created"
	WHEventSubscriptionUpdated        WebhookEventType = "subscription.updated"
	WHEventSubscriptionCanceled       WebhookEventType = "subscription.canceled"
	WHEventSubscriptionPastDue        WebhookEventType = "subscription.past_due"
	WHEventSubscriptionReactivated    WebhookEventType = "subscription.reactivated"
	WHEventPaymentSucceeded           WebhookEventType = "payment.succeeded"
	WHEventSubscriptionDisputeCreated WebhookEventType = "dispute.created"
)

View File

@@ -0,0 +1,5 @@
**Paddle hints:**
- **max_quantity on price:** Paddle limits `quantity` on a price to 100 by default. You need to explicitly set the range (`quantity: {minimum: 20, maximum: 10000}`) when creating a price via API or dashboard. Otherwise requests with quantity > 100 will return an error.
- **Full items list on update:** Unlike Stripe, Paddle requires sending **all** subscription items in `PATCH /subscriptions/{id}`, not just the changed ones. `proration_billing_mode` is also required. Without this you can accidentally remove a line item or get a 400.
- **Webhook events mapping:** Paddle uses `transaction.completed` instead of `payment.succeeded`, `transaction.payment_failed` instead of `payment.failed`, `adjustment.created` instead of `dispute.created`.

View File

@@ -0,0 +1,83 @@
package billing_paddle
import (
"encoding/json"
"errors"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
billing_webhooks "databasus-backend/internal/features/billing/webhooks"
"databasus-backend/internal/util/logger"
)
// PaddleBillingController handles HTTP traffic coming from Paddle itself
// (webhooks), as opposed to user-facing billing endpoints.
type PaddleBillingController struct {
	paddleBillingService *PaddleBillingService
}
// RegisterPublicRoutes mounts the webhook endpoint on the unauthenticated
// router group; Paddle requests are authenticated via the webhook signature
// instead of a user session.
func (c *PaddleBillingController) RegisterPublicRoutes(router *gin.RouterGroup) {
	router.POST("/billing/paddle/webhook", c.HandlePaddleWebhook)
}
// HandlePaddleWebhook
// @Summary Handle Paddle webhook
// @Description Process incoming webhook events from Paddle payment provider
// @Tags billing
// @Accept json
// @Produce json
// @Param Paddle-Signature header string true "Paddle webhook signature"
// @Success 200
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500
// @Router /billing/paddle/webhook [post]
func (c *PaddleBillingController) HandlePaddleWebhook(ctx *gin.Context) {
	requestID := uuid.New()
	log := logger.GetLogger().With("request_id", requestID)

	// Cap the body at 1 MiB so a hostile sender cannot exhaust memory.
	rawBody, err := io.ReadAll(io.LimitReader(ctx.Request.Body, 1<<20))
	if err != nil {
		log.Error("failed to read webhook request body", "error", err)
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Failed to read request body"})
		return
	}

	// Signature verification needs the original request headers.
	headers := map[string]string{}
	for name := range ctx.Request.Header {
		headers[name] = ctx.Request.Header.Get(name)
	}
	if err := c.paddleBillingService.VerifyWebhookSignature(rawBody, headers); err != nil {
		log.Warn("paddle webhook signature verification failed", "error", err)
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid webhook signature"})
		return
	}

	var payload PaddleWebhookDTO
	if err := json.Unmarshal(rawBody, &payload); err != nil {
		log.Error("failed to unmarshal webhook payload", "error", err)
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid webhook payload"})
		return
	}

	log = log.With(
		"provider_event_id", payload.EventID,
		"event_type", payload.EventType,
	)

	err = c.paddleBillingService.ProcessWebhookEvent(log, requestID, payload, rawBody)
	switch {
	case err == nil:
		ctx.Status(http.StatusOK)
	case errors.Is(err, billing_webhooks.ErrDuplicateWebhook):
		// 200 tells Paddle not to retry an event that was already handled.
		log.Info("duplicate webhook event, returning 200 to not force retry")
		ctx.Status(http.StatusOK)
	default:
		log.Error("Failed to process paddle webhook", "error", err)
		ctx.Status(http.StatusInternalServerError)
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,72 @@
package billing_paddle
import (
"sync"
"github.com/PaddleHQ/paddle-go-sdk"
"databasus-backend/internal/config"
"databasus-backend/internal/features/billing"
billing_webhooks "databasus-backend/internal/features/billing/webhooks"
)
var (
	// Lazily initialized singletons; see GetPaddleBillingService.
	paddleBillingService    *PaddleBillingService
	paddleBillingController *PaddleBillingController
	initOnce                sync.Once
)
// GetPaddleBillingService lazily builds the Paddle service singleton.
// Returns nil when not running in cloud mode, and also when the Paddle
// client could not be constructed.
func GetPaddleBillingService() *PaddleBillingService {
	if !config.GetEnv().IsCloud {
		return nil
	}
	initOnce.Do(func() {
		// Sandbox and live clients differ only in the constructor; build the
		// client once instead of duplicating the whole wiring per mode.
		var paddleClient *paddle.SDK
		var err error
		if config.GetEnv().IsPaddleSandbox {
			paddleClient, err = paddle.NewSandbox(config.GetEnv().PaddleApiKey)
		} else {
			paddleClient, err = paddle.New(config.GetEnv().PaddleApiKey)
		}
		if err != nil {
			// NOTE(review): the error is swallowed and initOnce is consumed,
			// so the service stays nil forever without any log — consider
			// logging or failing fast at startup.
			return
		}
		paddleBillingService = &PaddleBillingService{
			paddleClient,
			paddle.NewWebhookVerifier(config.GetEnv().PaddleWebhookSecret),
			config.GetEnv().PaddlePriceID,
			billing_webhooks.WebhookRepository{},
			billing.GetBillingService(),
		}
		paddleBillingController = &PaddleBillingController{paddleBillingService}
	})
	return paddleBillingService
}
// GetPaddleBillingController returns the Paddle webhook controller, or nil
// when not running in cloud mode.
// NOTE(review): can also return nil in cloud mode if client initialization
// failed inside GetPaddleBillingService — callers should nil-check.
func GetPaddleBillingController() *PaddleBillingController {
	if !config.GetEnv().IsCloud {
		return nil
	}
	// Ensure service + controller are initialized
	GetPaddleBillingService()
	return paddleBillingController
}
// SetupDependencies injects the Paddle provider into the billing service,
// completing the wiring deferred to avoid a circular import.
// NOTE(review): when IsCloud is false this passes nil — confirm
// SetBillingProvider tolerates a nil provider.
func SetupDependencies() {
	billing.GetBillingService().SetBillingProvider(GetPaddleBillingService())
}

View File

@@ -0,0 +1,9 @@
package billing_paddle
import "encoding/json"
type PaddleWebhookDTO struct {
EventID string `json:"event_id"`
EventType string `json:"event_type"`
Data json.RawMessage
}

View File

@@ -0,0 +1,50 @@
package billing_paddle
import "time"
// Test fixture payloads used to fabricate Paddle webhook bodies in tests.

// TestSubscriptionCreatedPayload builds a subscription.created event.
type TestSubscriptionCreatedPayload struct {
	EventID     string
	SubID       string
	CustomerID  string
	DatabaseID  string
	QuantityGB  int
	PeriodStart time.Time
	PeriodEnd   time.Time
}

// TestSubscriptionUpdatedPayload builds a subscription.updated event,
// optionally carrying a scheduled change.
type TestSubscriptionUpdatedPayload struct {
	EventID               string
	SubID                 string
	CustomerID            string
	QuantityGB            int
	PeriodStart           time.Time
	PeriodEnd             time.Time
	HasScheduledChange    bool
	ScheduledChangeAction string
}

// TestSubscriptionCanceledPayload builds a subscription.canceled event.
type TestSubscriptionCanceledPayload struct {
	EventID    string
	SubID      string
	CustomerID string
}

// TestTransactionCompletedPayload builds a transaction.completed event
// (Paddle's equivalent of payment.succeeded).
type TestTransactionCompletedPayload struct {
	EventID     string
	TxnID       string
	SubID       string
	CustomerID  string
	TotalCents  int64
	QuantityGB  int
	PeriodStart time.Time
	PeriodEnd   time.Time
}

// TestSubscriptionPastDuePayload builds a past-due subscription event.
type TestSubscriptionPastDuePayload struct {
	EventID     string
	SubID       string
	CustomerID  string
	QuantityGB  int
	PeriodStart time.Time
	PeriodEnd   time.Time
}

View File

@@ -0,0 +1,638 @@
package billing_paddle
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strconv"
"time"
"github.com/PaddleHQ/paddle-go-sdk"
"github.com/google/uuid"
"databasus-backend/internal/features/billing"
billing_models "databasus-backend/internal/features/billing/models"
billing_provider "databasus-backend/internal/features/billing/provider"
billing_webhooks "databasus-backend/internal/features/billing/webhooks"
)
// PaddleBillingService implements the billing provider interface on top of
// the Paddle SDK and Paddle webhooks.
type PaddleBillingService struct {
	client *paddle.SDK
	// NOTE(review): name reads like a boolean; "webhookVerifier" would fit
	// the *paddle.WebhookVerifier type better.
	webhookVerified   *paddle.WebhookVerifier
	priceID           string
	webhookRepository billing_webhooks.WebhookRepository
	billingService    *billing.BillingService
}
// GetProviderName identifies this implementation as the Paddle provider.
func (s *PaddleBillingService) GetProviderName() billing_provider.ProviderName {
	return billing_provider.ProviderPaddle
}
// CreateCheckoutSession opens a Paddle transaction for the requested
// storage quantity and returns the created transaction id.
func (s *PaddleBillingService) CreateCheckoutSession(
	logger *slog.Logger,
	request billing_provider.CheckoutRequest,
) (string, error) {
	logger = logger.With("database_id", request.DatabaseID)
	logger.Debug(fmt.Sprintf("paddle: creating checkout session for %d GB", request.StorageGB))

	// The database id travels as custom data so webhooks can map the
	// transaction back to our database record.
	item := paddle.NewCreateTransactionItemsCatalogItem(&paddle.CatalogItem{
		PriceID:  s.priceID,
		Quantity: request.StorageGB,
	})
	transaction, err := s.client.CreateTransaction(context.Background(), &paddle.CreateTransactionRequest{
		Items:      []paddle.CreateTransactionItems{*item},
		CustomData: paddle.CustomData{"database_id": request.DatabaseID.String()},
		Checkout: &paddle.TransactionCheckout{
			URL: &request.SuccessURL,
		},
	})
	if err != nil {
		logger.Error("paddle: failed to create transaction", "error", err)
		return "", err
	}

	return transaction.ID, nil
}
// UpgradeQuantityWithSurcharge raises the subscription's storage quantity,
// charging the prorated difference immediately. The update is skipped
// (nil error) when the subscription is already at the requested quantity.
func (s *PaddleBillingService) UpgradeQuantityWithSurcharge(
	logger *slog.Logger,
	providerSubscriptionID string,
	quantityGB int,
) error {
	logger = logger.With("provider_subscription_id", providerSubscriptionID)
	logger.Debug(fmt.Sprintf("paddle: applying upgrade: new storage %d GB", quantityGB))

	// important: paddle requires to send all items
	// in the subscription when updating, not just the changed one
	subscription, err := s.client.GetSubscription(context.Background(), &paddle.GetSubscriptionRequest{
		SubscriptionID: providerSubscriptionID,
	})
	if err != nil {
		logger.Error("paddle: failed to get subscription", "error", err)
		return err
	}

	// Guard against an empty items list so we fail with a clear error
	// instead of panicking on Items[0].
	if len(subscription.Items) == 0 {
		logger.Error("paddle: subscription has no items")
		return fmt.Errorf("subscription %s has no items", providerSubscriptionID)
	}

	currentQuantity := subscription.Items[0].Quantity
	if currentQuantity == quantityGB {
		logger.Info("paddle: subscription already at requested quantity, skipping upgrade",
			"current_quantity_gb", currentQuantity,
			"requested_quantity_gb", quantityGB,
		)
		return nil
	}

	priceID := subscription.Items[0].Price.ID
	_, err = s.client.UpdateSubscription(context.Background(), &paddle.UpdateSubscriptionRequest{
		SubscriptionID: providerSubscriptionID,
		Items: paddle.NewPatchField([]paddle.SubscriptionUpdateCatalogItem{
			{
				PriceID:  priceID,
				Quantity: quantityGB,
			},
		}),
		// Prorated-immediately charges the difference right away.
		ProrationBillingMode: paddle.NewPatchField(paddle.ProrationBillingModeProratedImmediately),
	})
	if err != nil {
		logger.Error("paddle: failed to update subscription", "error", err)
		return err
	}

	logger.Debug("paddle: successfully applied upgrade")
	return nil
}
// ScheduleQuantityDowngradeFromNextBillingCycle lowers the subscription's
// storage quantity starting with the next billing period. No-op (nil error)
// when the quantity already matches or a scheduled change exists.
func (s *PaddleBillingService) ScheduleQuantityDowngradeFromNextBillingCycle(
	logger *slog.Logger,
	providerSubscriptionID string,
	quantityGB int,
) error {
	logger = logger.With("provider_subscription_id", providerSubscriptionID)
	logger.Debug(fmt.Sprintf("paddle: scheduling downgrade from next billing cycle: new storage %d GB", quantityGB))

	// important: paddle requires to send all items
	// in the subscription when updating, not just the changed one
	subscription, err := s.client.GetSubscription(context.Background(), &paddle.GetSubscriptionRequest{
		SubscriptionID: providerSubscriptionID,
	})
	if err != nil {
		logger.Error("paddle: failed to get subscription", "error", err)
		return err
	}

	// Guard against an empty items list so we fail with a clear error
	// instead of panicking on Items[0].
	if len(subscription.Items) == 0 {
		logger.Error("paddle: subscription has no items")
		return fmt.Errorf("subscription %s has no items", providerSubscriptionID)
	}

	currentQuantity := subscription.Items[0].Quantity
	if currentQuantity == quantityGB {
		logger.Info("paddle: subscription already at requested quantity, skipping downgrade",
			"current_quantity_gb", currentQuantity,
			"requested_quantity_gb", quantityGB,
		)
		return nil
	}

	if subscription.ScheduledChange != nil {
		logger.Info("paddle: subscription already has a scheduled change, skipping downgrade")
		return nil
	}

	priceID := subscription.Items[0].Price.ID
	// apply downgrade from next billing cycle by setting the proration billing mode to "prorate on next billing period"
	_, err = s.client.UpdateSubscription(context.Background(), &paddle.UpdateSubscriptionRequest{
		SubscriptionID: providerSubscriptionID,
		Items: paddle.NewPatchField([]paddle.SubscriptionUpdateCatalogItem{
			{
				PriceID:  priceID,
				Quantity: quantityGB,
			},
		}),
		ProrationBillingMode: paddle.NewPatchField(paddle.ProrationBillingModeFullNextBillingPeriod),
	})
	if err != nil {
		logger.Error("paddle: failed to update subscription for downgrade", "error", err)
		return fmt.Errorf("failed to update subscription: %w", err)
	}

	logger.Debug("paddle: successfully scheduled downgrade from next billing cycle")
	return nil
}
// GetSubscription fetches the subscription from Paddle and converts it into
// the provider-agnostic representation.
func (s *PaddleBillingService) GetSubscription(
	logger *slog.Logger,
	providerSubscriptionID string,
) (billing_provider.ProviderSubscription, error) {
	logger = logger.With("provider_subscription_id", providerSubscriptionID)
	logger.Debug("paddle: getting subscription details")

	subscription, err := s.client.GetSubscription(context.Background(), &paddle.GetSubscriptionRequest{
		SubscriptionID: providerSubscriptionID,
	})
	if err != nil {
		logger.Error("paddle: failed to get subscription", "error", err)
		return billing_provider.ProviderSubscription{}, err
	}

	// The debug log below reads Items[0]; guard so a malformed subscription
	// yields an error instead of an index-out-of-range panic.
	if len(subscription.Items) == 0 {
		logger.Error("paddle: subscription has no items")
		return billing_provider.ProviderSubscription{}, fmt.Errorf("subscription %s has no items", providerSubscriptionID)
	}

	logger.Debug(
		fmt.Sprintf(
			"paddle: successfully got subscription details: status=%s, quantity=%d",
			subscription.Status,
			subscription.Items[0].Quantity,
		),
	)

	return s.toProviderSubscription(logger, subscription)
}
// CreatePortalSession resolves a customer-facing payment-management URL for
// the customer's active (or past-due) subscription.
func (s *PaddleBillingService) CreatePortalSession(
	logger *slog.Logger,
	providerCustomerID, returnURL string,
) (string, error) {
	logger = logger.With("provider_customer_id", providerCustomerID)
	logger.Debug("paddle: creating portal session")

	listRequest := &paddle.ListSubscriptionsRequest{
		CustomerID: []string{providerCustomerID},
		Status: []string{
			string(paddle.SubscriptionStatusActive),
			string(paddle.SubscriptionStatusPastDue),
		},
	}
	subscriptions, err := s.client.ListSubscriptions(context.Background(), listRequest)
	if err != nil {
		logger.Error("paddle: failed to list subscriptions for portal session", "error", err)
		return "", err
	}

	// Only the first matching subscription is needed.
	iteration := subscriptions.Next(context.Background())
	if !iteration.Ok() {
		if iterErr := iteration.Err(); iterErr != nil {
			logger.Error("paddle: failed to iterate subscriptions", "error", iterErr)
			return "", iterErr
		}
		logger.Error("paddle: no active subscriptions found for customer")
		return "", fmt.Errorf("no active subscriptions found for customer %s", providerCustomerID)
	}

	subscription := iteration.Value()
	managementURL := subscription.ManagementURLs.UpdatePaymentMethod
	if managementURL == nil {
		logger.Error("paddle: subscription has no management URL")
		return "", fmt.Errorf("subscription %s has no management URL", subscription.ID)
	}
	return *managementURL, nil
}
// VerifyWebhookSignature validates a Paddle webhook signature by replaying the
// raw body and headers through the Paddle SDK verifier.
// It returns an error when the verification request cannot be built, when
// verification itself fails, or when the signature does not match.
func (s *PaddleBillingService) VerifyWebhookSignature(body []byte, headers map[string]string) error {
	req, err := http.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader(body))
	if err != nil {
		// The original discarded this error, which would lead to a nil-pointer
		// panic inside Verify on a malformed request.
		return fmt.Errorf("failed to build verification request: %w", err)
	}
	for k, v := range headers {
		req.Header.Set(k, v)
	}

	ok, err := s.webhookVerified.Verify(req)
	if err != nil {
		return fmt.Errorf("failed to verify webhook signature: %w", err)
	}
	if !ok {
		// Distinguish a signature mismatch from a verification failure so the
		// caller never receives a "%!w(<nil>)" message from wrapping a nil error.
		return errors.New("webhook signature mismatch")
	}
	return nil
}
// ProcessWebhookEvent normalizes, de-duplicates, persists, and dispatches one
// Paddle webhook delivery. Unsupported event types are recorded as skipped;
// a redelivery of an already-processed event returns
// billing_webhooks.ErrDuplicateWebhook.
func (s *PaddleBillingService) ProcessWebhookEvent(
	logger *slog.Logger,
	requestID uuid.UUID,
	webhookDTO PaddleWebhookDTO,
	rawBody []byte,
) error {
	webhookEvent, err := s.normalizeWebhookEvent(
		logger,
		requestID,
		webhookDTO.EventID,
		webhookDTO.EventType,
		webhookDTO.Data,
	)
	if err != nil {
		// Unsupported events are stored for auditing but never processed.
		if errors.Is(err, billing_webhooks.ErrUnsupportedEventType) {
			return s.skipWebhookEvent(logger, requestID, webhookDTO, rawBody)
		}
		logger.Error("paddle: failed to normalize webhook event", "error", err)
		return err
	}

	logArgs := []any{
		"provider_event_id", webhookEvent.ProviderEventID,
		"provider_subscription_id", webhookEvent.ProviderSubscriptionID,
		"provider_customer_id", webhookEvent.ProviderCustomerID,
	}
	if webhookEvent.DatabaseID != nil {
		logArgs = append(logArgs, "database_id", webhookEvent.DatabaseID)
	}
	logger = logger.With(logArgs...)

	// Idempotency check: providers may redeliver the same event; an event that
	// was already processed successfully must not be handled twice.
	existingRecord, err := s.webhookRepository.FindSuccessfulByProviderEventID(webhookEvent.ProviderEventID)
	if err == nil && existingRecord != nil {
		logger.Info("paddle: webhook already processed successfully, skipping",
			"existing_request_id", existingRecord.RequestID,
		)
		return billing_webhooks.ErrDuplicateWebhook
	}

	// Persist the raw payload BEFORE processing so failures remain auditable.
	webhookRecord := &billing_webhooks.WebhookRecord{
		RequestID:       requestID,
		ProviderName:    billing_provider.ProviderPaddle,
		EventType:       string(webhookEvent.Type),
		ProviderEventID: webhookEvent.ProviderEventID,
		RawPayload:      string(rawBody),
	}
	if err := s.webhookRepository.Insert(webhookRecord); err != nil {
		logger.Error("paddle: failed to save webhook record", "error", err)
		return err
	}

	if err := s.processWebhookEvent(logger, webhookEvent); err != nil {
		logger.Error("paddle: failed to process webhook event", "error", err)
		// Record the failure; a MarkError failure is logged but the original
		// processing error is what gets returned.
		if markErr := s.webhookRepository.MarkError(requestID.String(), err.Error()); markErr != nil {
			logger.Error("paddle: failed to mark webhook as errored", "error", markErr)
		}
		return err
	}

	// Failing to mark success is non-fatal: the event itself was handled.
	if markErr := s.webhookRepository.MarkProcessed(requestID.String()); markErr != nil {
		logger.Error("paddle: failed to mark webhook as processed", "error", markErr)
	}
	return nil
}
// skipWebhookEvent stores the raw payload of an unsupported event and flags it
// as skipped so it remains auditable but is never retried.
func (s *PaddleBillingService) skipWebhookEvent(
	logger *slog.Logger,
	requestID uuid.UUID,
	webhookDTO PaddleWebhookDTO,
	rawBody []byte,
) error {
	record := billing_webhooks.WebhookRecord{
		RequestID:       requestID,
		ProviderName:    billing_provider.ProviderPaddle,
		EventType:       webhookDTO.EventType,
		ProviderEventID: webhookDTO.EventID,
		RawPayload:      string(rawBody),
	}
	if insertErr := s.webhookRepository.Insert(&record); insertErr != nil {
		logger.Error("paddle: failed to save skipped webhook record", "error", insertErr)
		return insertErr
	}

	// A failure to flag the record as skipped is logged but not propagated:
	// the payload itself is already persisted.
	if markErr := s.webhookRepository.MarkSkipped(requestID.String()); markErr != nil {
		logger.Error("paddle: failed to mark webhook as skipped", "error", markErr)
	}
	return nil
}
// processWebhookEvent routes a normalized webhook event to the billing
// service. Ordering matters: created and dispute events are dispatched BEFORE
// the subscription lookup because no subscription row exists yet (created) or
// no provider subscription ID is available (dispute).
func (s *PaddleBillingService) processWebhookEvent(
	logger *slog.Logger,
	webhookEvent billing_models.WebhookEvent,
) error {
	logger.Debug("processing webhook event")

	// subscription.created - there is no subscription in the database yet
	if webhookEvent.Type == billing_models.WHEventSubscriptionCreated {
		return s.billingService.ActivateSubscription(logger, webhookEvent)
	}

	// dispute - finds subscription via invoice, no provider subscription ID available
	if webhookEvent.Type == billing_models.WHEventSubscriptionDisputeCreated {
		return s.billingService.RecordDispute(logger, webhookEvent)
	}

	// for others - search subscription first
	subscription, err := s.billingService.GetSubscriptionByProviderSubID(logger, webhookEvent.ProviderSubscriptionID)
	if err != nil {
		logger.Error("paddle: failed to find subscription for webhook event", "error", err)
		return err
	}
	logger = logger.With("subscription_id", subscription.ID, "database_id", subscription.DatabaseID)
	logger.Debug(fmt.Sprintf("found subscription in DB with ID: %s", subscription.ID))

	switch webhookEvent.Type {
	case billing_models.WHEventSubscriptionUpdated:
		// updates for a canceled subscription are reactivated rather than synced
		if subscription.Status == billing_models.StatusCanceled {
			return s.billingService.ReactivateSubscription(logger, subscription, webhookEvent)
		}
		return s.billingService.SyncSubscriptionFromProvider(logger, subscription, webhookEvent)
	case billing_models.WHEventSubscriptionCanceled:
		return s.billingService.CancelSubscription(logger, subscription, webhookEvent)
	case billing_models.WHEventPaymentSucceeded:
		return s.billingService.RecordPaymentSuccess(logger, subscription, webhookEvent)
	case billing_models.WHEventSubscriptionPastDue:
		return s.billingService.RecordPaymentFailed(logger, subscription, webhookEvent)
	default:
		// should be unreachable for events produced by normalizeWebhookEvent;
		// logged defensively instead of failing the webhook
		logger.Error(fmt.Sprintf("unhandled webhook event type: %s", string(webhookEvent.Type)))
		return nil
	}
}
// normalizeWebhookEvent converts a raw Paddle webhook payload into the
// provider-agnostic billing_models.WebhookEvent.
//
// It returns billing_webhooks.ErrUnsupportedEventType for event types this
// integration does not handle, so the caller can record and skip them.
func (s *PaddleBillingService) normalizeWebhookEvent(
	logger *slog.Logger,
	requestID uuid.UUID,
	eventID, eventType string,
	data json.RawMessage,
) (billing_models.WebhookEvent, error) {
	webhookEvent := billing_models.WebhookEvent{
		RequestID:       requestID,
		ProviderEventID: eventID,
	}

	switch eventType {
	case "subscription.created":
		webhookEvent.Type = billing_models.WHEventSubscriptionCreated
		var subscription paddle.Subscription
		if err := json.Unmarshal(data, &subscription); err != nil {
			logger.Error("paddle: failed to unmarshal subscription.created webhook data", "error", err)
			return webhookEvent, err
		}
		// guard Items[0] access: a malformed payload must not panic the handler
		if len(subscription.Items) == 0 {
			logger.Error("paddle: subscription.created webhook has no items")
			return webhookEvent, fmt.Errorf("paddle subscription %s has no items", subscription.ID)
		}
		webhookEvent.ProviderSubscriptionID = subscription.ID
		webhookEvent.ProviderCustomerID = subscription.CustomerID
		webhookEvent.QuantityGB = subscription.Items[0].Quantity

		status, err := mapPaddleStatus(logger, subscription.Status)
		if err != nil {
			return webhookEvent, err
		}
		webhookEvent.Status = status

		if subscription.CurrentBillingPeriod != nil {
			periodStart := mustParseRFC3339(logger, "period start", subscription.CurrentBillingPeriod.StartsAt)
			periodEnd := mustParseRFC3339(logger, "period end", subscription.CurrentBillingPeriod.EndsAt)
			webhookEvent.PeriodStart = &periodStart
			webhookEvent.PeriodEnd = &periodEnd
		}

		// database_id is attached as checkout custom data; without it the new
		// subscription cannot be linked to a database.
		databaseIDStr, isOk := subscription.CustomData["database_id"].(string)
		if !isOk {
			logger.Error("paddle: database_id in custom data is missing or not a string")
			return webhookEvent, fmt.Errorf("invalid database_id type in custom data")
		}
		// uuid.Parse instead of uuid.MustParse: a malformed value in an
		// externally supplied webhook must not panic the handler.
		databaseID, err := uuid.Parse(databaseIDStr)
		if err != nil {
			logger.Error("paddle: invalid database_id in custom data", "error", err)
			return webhookEvent, fmt.Errorf("invalid database_id in custom data: %w", err)
		}
		webhookEvent.DatabaseID = &databaseID
	case "subscription.updated":
		webhookEvent.Type = billing_models.WHEventSubscriptionUpdated
		var subscription paddle.Subscription
		if err := json.Unmarshal(data, &subscription); err != nil {
			return webhookEvent, err
		}
		if len(subscription.Items) == 0 {
			return webhookEvent, fmt.Errorf("paddle subscription %s has no items", subscription.ID)
		}
		webhookEvent.ProviderSubscriptionID = subscription.ID
		webhookEvent.ProviderCustomerID = subscription.CustomerID
		webhookEvent.QuantityGB = subscription.Items[0].Quantity

		status, err := mapPaddleStatus(logger, subscription.Status)
		if err != nil {
			return webhookEvent, err
		}
		webhookEvent.Status = status

		if subscription.CurrentBillingPeriod != nil {
			periodStart := mustParseRFC3339(logger, "period start", subscription.CurrentBillingPeriod.StartsAt)
			periodEnd := mustParseRFC3339(logger, "period end", subscription.CurrentBillingPeriod.EndsAt)
			webhookEvent.PeriodStart = &periodStart
			webhookEvent.PeriodEnd = &periodEnd
		}
		// a pending scheduled cancellation is surfaced as a cancel event
		if subscription.ScheduledChange != nil &&
			subscription.ScheduledChange.Action == paddle.ScheduledChangeActionCancel {
			webhookEvent.Type = billing_models.WHEventSubscriptionCanceled
		}
	case "subscription.canceled":
		webhookEvent.Type = billing_models.WHEventSubscriptionCanceled
		var subscription paddle.Subscription
		if err := json.Unmarshal(data, &subscription); err != nil {
			return webhookEvent, err
		}
		webhookEvent.ProviderSubscriptionID = subscription.ID
		webhookEvent.ProviderCustomerID = subscription.CustomerID

		status, err := mapPaddleStatus(logger, subscription.Status)
		if err != nil {
			return webhookEvent, err
		}
		webhookEvent.Status = status
	case "transaction.completed":
		webhookEvent.Type = billing_models.WHEventPaymentSucceeded
		var transaction paddle.Transaction
		if err := json.Unmarshal(data, &transaction); err != nil {
			return webhookEvent, err
		}
		webhookEvent.ProviderInvoiceID = transaction.ID
		if len(transaction.Items) > 0 {
			webhookEvent.QuantityGB = transaction.Items[0].Quantity
		}
		if transaction.SubscriptionID != nil {
			webhookEvent.ProviderSubscriptionID = *transaction.SubscriptionID
		}
		if transaction.CustomerID != nil {
			webhookEvent.ProviderCustomerID = *transaction.CustomerID
		}
		// Paddle reports totals as decimal strings of cents; an unparsable
		// total is logged and left at zero rather than failing the event.
		amountCents, err := strconv.ParseInt(transaction.Details.Totals.Total, 10, 64)
		if err != nil {
			logger.Error("paddle: failed to parse transaction total", "error", err)
		} else {
			webhookEvent.AmountCents = amountCents
		}
		if transaction.BillingPeriod != nil {
			periodStart := mustParseRFC3339(logger, "period start", transaction.BillingPeriod.StartsAt)
			periodEnd := mustParseRFC3339(logger, "period end", transaction.BillingPeriod.EndsAt)
			webhookEvent.PeriodStart = &periodStart
			webhookEvent.PeriodEnd = &periodEnd
		}
	case "subscription.past_due":
		webhookEvent.Type = billing_models.WHEventSubscriptionPastDue
		var subscription paddle.Subscription
		if err := json.Unmarshal(data, &subscription); err != nil {
			return webhookEvent, err
		}
		if len(subscription.Items) == 0 {
			return webhookEvent, fmt.Errorf("paddle subscription %s has no items", subscription.ID)
		}
		webhookEvent.ProviderSubscriptionID = subscription.ID
		webhookEvent.ProviderCustomerID = subscription.CustomerID
		webhookEvent.QuantityGB = subscription.Items[0].Quantity

		status, err := mapPaddleStatus(logger, subscription.Status)
		if err != nil {
			return webhookEvent, err
		}
		webhookEvent.Status = status

		if subscription.CurrentBillingPeriod != nil {
			periodStart := mustParseRFC3339(logger, "period start", subscription.CurrentBillingPeriod.StartsAt)
			periodEnd := mustParseRFC3339(logger, "period end", subscription.CurrentBillingPeriod.EndsAt)
			webhookEvent.PeriodStart = &periodStart
			webhookEvent.PeriodEnd = &periodEnd
		}
	case "adjustment.created":
		// Disputes arrive as adjustments; only the transaction reference is
		// extracted so the billing service can locate the subscription by invoice.
		webhookEvent.Type = billing_models.WHEventSubscriptionDisputeCreated
		var adjustment struct {
			TransactionID string `json:"transaction_id"`
		}
		if err := json.Unmarshal(data, &adjustment); err != nil {
			return webhookEvent, err
		}
		webhookEvent.ProviderInvoiceID = adjustment.TransactionID
	default:
		logger.Debug("unsupported paddle event type, skipping", "event_type", eventType)
		return webhookEvent, billing_webhooks.ErrUnsupportedEventType
	}
	return webhookEvent, nil
}
// toProviderSubscription maps a Paddle SDK subscription onto the
// provider-agnostic representation used by the billing layer.
// It fails when the status is unknown or the subscription has no items.
func (s *PaddleBillingService) toProviderSubscription(
	logger *slog.Logger,
	paddleSubscription *paddle.Subscription,
) (billing_provider.ProviderSubscription, error) {
	status, err := mapPaddleStatus(logger, paddleSubscription.Status)
	if err != nil {
		return billing_provider.ProviderSubscription{}, err
	}
	if len(paddleSubscription.Items) == 0 {
		return billing_provider.ProviderSubscription{}, fmt.Errorf(
			"paddle subscription %s has no items",
			paddleSubscription.ID,
		)
	}

	result := billing_provider.ProviderSubscription{
		ProviderSubscriptionID: paddleSubscription.ID,
		ProviderCustomerID:     paddleSubscription.CustomerID,
		Status:                 status,
		QuantityGB:             paddleSubscription.Items[0].Quantity,
	}

	if billingPeriod := paddleSubscription.CurrentBillingPeriod; billingPeriod != nil {
		result.PeriodStart = mustParseRFC3339(logger, "period start", billingPeriod.StartsAt)
		result.PeriodEnd = mustParseRFC3339(logger, "period end", billingPeriod.EndsAt)
	}
	return result, nil
}
// mustParseRFC3339 parses value as an RFC3339 timestamp.
// NOTE(review): despite the "must" prefix this does NOT panic — on a parse
// failure it logs the error and returns the zero time.Time, which then
// propagates into period fields; consider renaming or returning an error.
func mustParseRFC3339(logger *slog.Logger, label, value string) time.Time {
	parsed, err := time.Parse(time.RFC3339, value)
	if err != nil {
		logger.Error(fmt.Sprintf("paddle: failed to parse %s", label), "error", err)
	}
	return parsed
}
// mapPaddleStatus converts a Paddle subscription status into the internal
// billing status. Paused subscriptions map to canceled; an unknown status is
// logged and returned as an error.
func mapPaddleStatus(logger *slog.Logger, s paddle.SubscriptionStatus) (billing_models.SubscriptionStatus, error) {
	switch s {
	case paddle.SubscriptionStatusActive:
		return billing_models.StatusActive, nil
	case paddle.SubscriptionStatusPastDue:
		return billing_models.StatusPastDue, nil
	case paddle.SubscriptionStatusTrialing:
		return billing_models.StatusTrial, nil
	case paddle.SubscriptionStatusCanceled, paddle.SubscriptionStatusPaused:
		// paused has no internal equivalent and is treated as canceled
		return billing_models.StatusCanceled, nil
	default:
		logger.Error(fmt.Sprintf("paddle: unknown subscription status: %s", string(s)))
		return "", fmt.Errorf("paddle: unknown subscription status: %s", string(s))
	}
}

View File

@@ -0,0 +1,38 @@
package billing_provider
import (
"time"
"github.com/google/uuid"
billing_models "databasus-backend/internal/features/billing/models"
)
// CreateSubscriptionRequest carries the inputs needed to create a provider
// subscription paying for a database's storage.
type CreateSubscriptionRequest struct {
	ProviderCustomerID string    // billing-provider customer identifier
	DatabaseID         uuid.UUID // database the subscription pays for
	StorageGB          int       // purchased storage quota in GB
}

// ProviderSubscription is the provider-agnostic view of a subscription as
// reported by the billing provider.
type ProviderSubscription struct {
	ProviderSubscriptionID string
	ProviderCustomerID     string
	Status                 billing_models.SubscriptionStatus
	QuantityGB             int       // current paid storage quantity in GB
	PeriodStart            time.Time // current billing period start
	PeriodEnd              time.Time // current billing period end
}

// CheckoutRequest carries the inputs needed to start a provider checkout flow.
type CheckoutRequest struct {
	DatabaseID uuid.UUID
	Email      string // customer email used for the checkout session
	StorageGB  int    // storage quota being purchased, in GB
	SuccessURL string // redirect target after successful payment
	CancelURL  string // redirect target when the customer aborts checkout
}

// ProviderName identifies a supported billing provider implementation.
type ProviderName string

const (
	// ProviderPaddle is the Paddle billing provider.
	ProviderPaddle ProviderName = "paddle"
)

View File

@@ -0,0 +1,21 @@
package billing_provider
import "log/slog"
// BillingProvider abstracts a third-party billing backend (e.g. Paddle)
// behind the operations the billing feature needs.
type BillingProvider interface {
	// GetProviderName returns the provider's identifier.
	GetProviderName() ProviderName
	// UpgradeQuantityWithSurcharge raises the subscription quantity.
	UpgradeQuantityWithSurcharge(logger *slog.Logger, providerSubscriptionID string, quantityGB int) error
	// ScheduleQuantityDowngradeFromNextBillingCycle lowers the quantity
	// starting with the next billing period rather than immediately.
	ScheduleQuantityDowngradeFromNextBillingCycle(
		logger *slog.Logger,
		providerSubscriptionID string,
		quantityGB int,
	) error
	// GetSubscription fetches the current provider-side subscription state.
	GetSubscription(logger *slog.Logger, providerSubscriptionID string) (ProviderSubscription, error)
	// CreateCheckoutSession starts a checkout flow and returns its URL.
	CreateCheckoutSession(logger *slog.Logger, req CheckoutRequest) (checkoutURL string, err error)
	// CreatePortalSession returns a customer-facing management URL.
	CreatePortalSession(logger *slog.Logger, providerCustomerID, returnURL string) (portalURL string, err error)
}

View File

@@ -0,0 +1,72 @@
package billing_repositories
import (
"errors"
"github.com/google/uuid"
"gorm.io/gorm"
billing_models "databasus-backend/internal/features/billing/models"
"databasus-backend/internal/storage"
)
// InvoiceRepository persists billing invoices.
type InvoiceRepository struct{}

// Save inserts the invoice when it has no ID yet, otherwise updates it.
// A subscription ID is required in both cases.
func (r *InvoiceRepository) Save(invoice billing_models.Invoice) error {
	if invoice.SubscriptionID == uuid.Nil {
		return errors.New("subscription id is required")
	}

	db := storage.GetDb()
	if invoice.ID == uuid.Nil {
		invoice.ID = uuid.New()
		return db.Create(&invoice).Error
	}
	// Pass a pointer for consistency with Create above and with GORM's
	// convention (the original passed the struct by value).
	return db.Save(&invoice).Error
}
// FindByProviderInvID looks an invoice up by its provider-side invoice ID.
// It returns (nil, nil) when no matching invoice exists.
func (r *InvoiceRepository) FindByProviderInvID(providerInvoiceID string) (*billing_models.Invoice, error) {
	var invoice billing_models.Invoice
	err := storage.GetDb().
		Where("provider_invoice_id = ?", providerInvoiceID).
		First(&invoice).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &invoice, nil
}
// FindByDatabaseID returns one page of invoices for a database, newest first.
func (r *InvoiceRepository) FindByDatabaseID(
	databaseID uuid.UUID,
	limit, offset int,
) ([]*billing_models.Invoice, error) {
	query := storage.GetDb().
		Joins("JOIN subscriptions ON subscriptions.id = invoices.subscription_id").
		Where("subscriptions.database_id = ?", databaseID).
		Order("invoices.created_at DESC").
		Limit(limit).
		Offset(offset)

	var invoices []*billing_models.Invoice
	if err := query.Find(&invoices).Error; err != nil {
		return nil, err
	}
	return invoices, nil
}
// CountByDatabaseID returns the total number of invoices for a database.
func (r *InvoiceRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error) {
	var total int64
	query := storage.GetDb().
		Model(&billing_models.Invoice{}).
		Joins("JOIN subscriptions ON subscriptions.id = invoices.subscription_id").
		Where("subscriptions.database_id = ?", databaseID)
	err := query.Count(&total).Error
	return total, err
}

View File

@@ -0,0 +1,50 @@
package billing_repositories
import (
"errors"
"github.com/google/uuid"
billing_models "databasus-backend/internal/features/billing/models"
"databasus-backend/internal/storage"
)
// SubscriptionEventRepository persists subscription lifecycle events.
type SubscriptionEventRepository struct{}

// Create inserts a new subscription event with a freshly generated ID.
// The event must reference an existing subscription.
func (r *SubscriptionEventRepository) Create(event billing_models.SubscriptionEvent) error {
	if event.SubscriptionID == uuid.Nil {
		return errors.New("subscription id is required")
	}

	event.ID = uuid.New()
	db := storage.GetDb()
	return db.Create(&event).Error
}
// FindByDatabaseID returns one page of subscription events for a database,
// newest first.
func (r *SubscriptionEventRepository) FindByDatabaseID(
	databaseID uuid.UUID,
	limit, offset int,
) ([]*billing_models.SubscriptionEvent, error) {
	query := storage.GetDb().
		Joins("JOIN subscriptions ON subscriptions.id = subscription_events.subscription_id").
		Where("subscriptions.database_id = ?", databaseID).
		Order("subscription_events.created_at DESC").
		Limit(limit).
		Offset(offset)

	var events []*billing_models.SubscriptionEvent
	if err := query.Find(&events).Error; err != nil {
		return nil, err
	}
	return events, nil
}
// CountByDatabaseID returns the total number of subscription events for a
// database.
func (r *SubscriptionEventRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error) {
	var total int64
	query := storage.GetDb().
		Model(&billing_models.SubscriptionEvent{}).
		Joins("JOIN subscriptions ON subscriptions.id = subscription_events.subscription_id").
		Where("subscriptions.database_id = ?", databaseID)
	err := query.Count(&total).Error
	return total, err
}

View File

@@ -0,0 +1,123 @@
package billing_repositories
import (
"errors"
"time"
"github.com/google/uuid"
"gorm.io/gorm"
billing_models "databasus-backend/internal/features/billing/models"
"databasus-backend/internal/storage"
)
// SubscriptionRepository persists billing subscriptions.
type SubscriptionRepository struct{}

// Save inserts the subscription when it has no ID yet, otherwise updates it.
func (r *SubscriptionRepository) Save(sub billing_models.Subscription) error {
	db := storage.GetDb()
	if sub.ID != uuid.Nil {
		return db.Save(&sub).Error
	}

	sub.ID = uuid.New()
	return db.Create(&sub).Error
}
// FindByID looks a subscription up by primary key; (nil, nil) when absent.
func (r *SubscriptionRepository) FindByID(id uuid.UUID) (*billing_models.Subscription, error) {
	var sub billing_models.Subscription
	err := storage.GetDb().Where("id = ?", id).First(&sub).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &sub, nil
}
// FindByDatabaseIDAndStatuses returns all subscriptions for a database whose
// status is in the given set.
func (r *SubscriptionRepository) FindByDatabaseIDAndStatuses(
	databaseID uuid.UUID,
	statuses []billing_models.SubscriptionStatus, // fixes the "stauses" typo
) ([]*billing_models.Subscription, error) {
	var subs []*billing_models.Subscription
	if err := storage.GetDb().Where("database_id = ? AND status IN ?", databaseID, statuses).
		Find(&subs).Error; err != nil {
		return nil, err
	}
	return subs, nil
}
// FindLatestByDatabaseID returns the most recently created subscription for a
// database, or (nil, nil) when the database has none.
func (r *SubscriptionRepository) FindLatestByDatabaseID(databaseID uuid.UUID) (*billing_models.Subscription, error) {
	var sub billing_models.Subscription
	err := storage.GetDb().
		Where("database_id = ?", databaseID).
		Order("created_at DESC").
		First(&sub).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &sub, nil
}
// FindByProviderSubID looks a subscription up by its provider-side
// subscription ID; (nil, nil) when absent.
func (r *SubscriptionRepository) FindByProviderSubID(providerSubID string) (*billing_models.Subscription, error) {
	var sub billing_models.Subscription
	err := storage.GetDb().
		Where("provider_sub_id = ?", providerSubID).
		First(&sub).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &sub, nil
}
// FindByStatuses returns every subscription whose status is in the given set.
func (r *SubscriptionRepository) FindByStatuses(
	statuses []billing_models.SubscriptionStatus,
) ([]billing_models.Subscription, error) {
	var subs []billing_models.Subscription
	err := storage.GetDb().Where("status IN ?", statuses).Find(&subs).Error
	if err != nil {
		return nil, err
	}
	return subs, nil
}
// FindCanceledWithEndedGracePeriod returns canceled subscriptions whose data
// retention grace period ended before the given moment.
func (r *SubscriptionRepository) FindCanceledWithEndedGracePeriod(
	now time.Time,
) ([]billing_models.Subscription, error) {
	var subs []billing_models.Subscription
	query := storage.GetDb().
		Where("status = ? AND data_retention_grace_period_until < ?", billing_models.StatusCanceled, now)
	if err := query.Find(&subs).Error; err != nil {
		return nil, err
	}
	return subs, nil
}
// FindExpiredTrials returns trial subscriptions whose current billing period
// ended before the given moment.
func (r *SubscriptionRepository) FindExpiredTrials(now time.Time) ([]billing_models.Subscription, error) {
	var subs []billing_models.Subscription
	query := storage.GetDb().
		Where("status = ? AND current_period_end < ?", billing_models.StatusTrial, now)
	if err := query.Find(&subs).Error; err != nil {
		return nil, err
	}
	return subs, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,8 @@
package billing_webhooks
import "errors"
var (
	// ErrDuplicateWebhook signals that the provider event was already
	// processed successfully and must not be handled again.
	ErrDuplicateWebhook = errors.New("duplicate webhook event")
	// ErrUnsupportedEventType signals an event type this integration does not
	// handle; such events are recorded and skipped rather than failed.
	ErrUnsupportedEventType = errors.New("unsupported webhook event type")
)

View File

@@ -0,0 +1,25 @@
package billing_webhooks
import (
"time"
"github.com/google/uuid"
billing_provider "databasus-backend/internal/features/billing/provider"
)
// WebhookRecord is one received provider webhook delivery, stored for
// idempotency (dedup by ProviderEventID) and auditing (raw payload kept).
type WebhookRecord struct {
	RequestID       uuid.UUID                     `gorm:"column:request_id;primaryKey;type:uuid;default:gen_random_uuid()"`
	ProviderName    billing_provider.ProviderName `gorm:"column:provider_name;type:text;not null"`
	EventType       string                        `gorm:"column:event_type;type:text;not null"`
	ProviderEventID string                        `gorm:"column:provider_event_id;type:text;not null;index"`
	RawPayload      string                        `gorm:"column:raw_payload;type:text;not null"`
	// ProcessedAt is set once the event is handled or skipped; nil while pending.
	ProcessedAt *time.Time `gorm:"column:processed_at"`
	// IsSkipped marks events of unsupported types that were stored but not handled.
	IsSkipped bool `gorm:"column:is_skipped;not null;default:false"`
	// Error holds the processing failure message, if any.
	Error     *string   `gorm:"column:error"`
	CreatedAt time.Time `gorm:"column:created_at;not null"`
}

// TableName maps the model to the webhook_records table.
func (WebhookRecord) TableName() string {
	return "webhook_records"
}

View File

@@ -0,0 +1,73 @@
package billing_webhooks
import (
"errors"
"time"
"gorm.io/gorm"
"databasus-backend/internal/storage"
)
// WebhookRepository persists webhook delivery records used for idempotency.
type WebhookRepository struct{}

// FindSuccessfulByProviderEventID returns the already-processed webhook record
// for the given provider event ID, or (nil, nil) when no such record exists.
func (r *WebhookRepository) FindSuccessfulByProviderEventID(providerEventID string) (*WebhookRecord, error) {
	var record WebhookRecord
	err := storage.GetDb().
		Where("provider_event_id = ? AND processed_at IS NOT NULL", providerEventID).
		First(&record).Error
	if err != nil {
		// A missing record is not an error: the event simply has not been
		// processed yet. (The original returned the raw gorm.ErrRecordNotFound
		// here, making both branches identical and defeating the upstream
		// duplicate check.)
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}
	return &record, nil
}
// Insert stores a new webhook record, stamping CreatedAt with the current UTC
// time. The provider event ID is mandatory.
func (r *WebhookRepository) Insert(record *WebhookRecord) error {
	if record.ProviderEventID == "" {
		return errors.New("provider event ID is required")
	}

	record.CreatedAt = time.Now().UTC()
	db := storage.GetDb()
	return db.Create(record).Error
}
// MarkProcessed stamps the record's processed_at with the current UTC time.
func (r *WebhookRepository) MarkProcessed(requestID string) error {
	return storage.GetDb().
		Model(&WebhookRecord{}).
		Where("request_id = ?", requestID).
		Update("processed_at", time.Now().UTC()).
		Error
}
// MarkSkipped flags the record as skipped and stamps processed_at so the
// event is treated as finished and never retried.
func (r *WebhookRepository) MarkSkipped(requestID string) error {
	updates := map[string]any{
		"is_skipped":   true,
		"processed_at": time.Now().UTC(),
	}
	return storage.GetDb().
		Model(&WebhookRecord{}).
		Where("request_id = ?", requestID).
		Updates(updates).
		Error
}
// MarkError stores the processing error message on the record; processed_at
// is left untouched so the event still counts as unprocessed.
func (r *WebhookRepository) MarkError(requestID, errMsg string) error {
	db := storage.GetDb()
	return db.
		Model(&WebhookRecord{}).
		Where("request_id = ?", requestID).
		Update("error", errMsg).
		Error
}

View File

@@ -1328,6 +1328,143 @@ func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
}
}
// Verifies that cloud mode rejects database creation with a regular
// (non read-only) database user, responding with HTTP 400.
func Test_CreateDatabase_WhenCloudAndUserIsNotReadOnly_ReturnsBadRequest(t *testing.T) {
	enableCloud(t)
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Cloud Not ReadOnly", owner, router)
	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

	// getTestPostgresConfig uses the default (full-access) postgres user.
	request := Database{
		Name:        "Cloud Non-ReadOnly DB",
		WorkspaceID: &workspace.ID,
		Type:        DatabaseTypePostgres,
		Postgresql:  getTestPostgresConfig(),
	}
	resp := test_utils.MakePostRequest(
		t,
		router,
		"/api/v1/databases/create",
		"Bearer "+owner.Token,
		request,
		http.StatusBadRequest,
	)
	assert.Contains(t, string(resp.Body), "in cloud mode, only read-only database users are allowed")
}
// Verifies the full cloud flow: a read-only DB user is provisioned while the
// instance is still self-hosted, then cloud mode is enabled and database
// creation with those read-only credentials succeeds.
func Test_CreateDatabase_WhenCloudAndUserIsReadOnly_DatabaseCreated(t *testing.T) {
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Cloud ReadOnly", owner, router)
	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

	// Provision the read-only user BEFORE enabling cloud mode: the temporary
	// database used for provisioning would itself be rejected in cloud mode.
	database := createTestDatabaseViaAPI("Temp DB for RO User", workspace.ID, owner.Token, router)
	readOnlyUser := createReadOnlyUserViaAPI(t, router, database.ID, owner.Token)
	assert.NotEmpty(t, readOnlyUser.Username)
	assert.NotEmpty(t, readOnlyUser.Password)
	RemoveTestDatabase(database)

	enableCloud(t)
	pgConfig := getTestPostgresConfig()
	pgConfig.Username = readOnlyUser.Username
	pgConfig.Password = readOnlyUser.Password
	request := Database{
		Name:        "Cloud ReadOnly DB",
		WorkspaceID: &workspace.ID,
		Type:        DatabaseTypePostgres,
		Postgresql:  pgConfig,
	}
	var response Database
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/databases/create",
		"Bearer "+owner.Token,
		request,
		http.StatusCreated,
		&response,
	)
	defer RemoveTestDatabase(&response)
	assert.Equal(t, "Cloud ReadOnly DB", response.Name)
	assert.NotEqual(t, uuid.Nil, response.ID)
}
// Verifies that in self-hosted (non-cloud) mode a regular, non read-only
// database user is accepted and the database is created.
func Test_CreateDatabase_WhenNotCloudAndUserIsNotReadOnly_DatabaseCreated(t *testing.T) {
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Non-Cloud", owner, router)
	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

	request := Database{
		Name:        "Non-Cloud DB",
		WorkspaceID: &workspace.ID,
		Type:        DatabaseTypePostgres,
		Postgresql:  getTestPostgresConfig(),
	}
	var response Database
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/databases/create",
		"Bearer "+owner.Token,
		request,
		http.StatusCreated,
		&response,
	)
	defer RemoveTestDatabase(&response)
	assert.Equal(t, "Non-Cloud DB", response.Name)
	assert.NotEqual(t, uuid.Nil, response.ID)
}
// enableCloud switches the process-wide config into cloud mode for the
// duration of the test and restores self-hosted mode on cleanup.
// It mutates global state, so tests using it must not run in parallel.
func enableCloud(t *testing.T) {
	t.Helper()

	env := config.GetEnv()
	env.IsCloud = true
	t.Cleanup(func() {
		env.IsCloud = false
	})
}
// createReadOnlyUserViaAPI fetches the database by ID (to obtain its full
// connection details) and asks the API to provision a read-only database user
// for it, returning the generated credentials.
func createReadOnlyUserViaAPI(
	t *testing.T,
	router *gin.Engine,
	databaseID uuid.UUID,
	token string,
) *CreateReadOnlyUserResponse {
	var database Database
	test_utils.MakeGetRequestAndUnmarshal(
		t,
		router,
		fmt.Sprintf("/api/v1/databases/%s", databaseID.String()),
		"Bearer "+token,
		http.StatusOK,
		&database,
	)

	// The create-readonly-user endpoint takes the full database payload.
	var response CreateReadOnlyUserResponse
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/databases/create-readonly-user",
		"Bearer "+token,
		database,
		http.StatusOK,
		&response,
	)
	return &response
}
func getTestMariadbConfig() *mariadb.MariadbDatabase {
env := config.GetEnv()
portStr := env.TestMariadb1011Port

View File

@@ -81,8 +81,8 @@ func (p *PostgresqlDatabase) Validate() error {
p.BackupType = PostgresBackupTypePgDump
}
if p.BackupType == PostgresBackupTypePgDump && config.GetEnv().IsCloud {
return errors.New("PG_DUMP backup type is not supported in cloud mode")
if p.BackupType != PostgresBackupTypePgDump && config.GetEnv().IsCloud {
return errors.New("only PG_DUMP backup type is supported in cloud mode")
}
if p.BackupType == PostgresBackupTypePgDump {

View File

@@ -1310,6 +1310,46 @@ func Test_Validate_WhenRequiredFieldsMissing_ReturnsError(t *testing.T) {
}
}
// Verifies that cloud mode rejects any PostgreSQL backup type other than PG_DUMP.
func Test_Validate_WhenCloudAndBackupTypeIsNotPgDump_ValidationFails(t *testing.T) {
	enableCloud(t)

	postgresModel := &PostgresqlDatabase{
		Host:       "example.com",
		Port:       5432,
		Username:   "user",
		Password:   "pass",
		CpuCount:   1,
		BackupType: PostgresBackupTypeWalV1,
	}

	validationErr := postgresModel.Validate()

	assert.EqualError(t, validationErr, "only PG_DUMP backup type is supported in cloud mode")
}
// Verifies that the PG_DUMP backup type passes validation in cloud mode.
func Test_Validate_WhenCloudAndBackupTypeIsPgDump_ValidationPasses(t *testing.T) {
	enableCloud(t)

	postgresModel := &PostgresqlDatabase{
		Host:       "example.com",
		Port:       5432,
		Username:   "user",
		Password:   "pass",
		CpuCount:   1,
		BackupType: PostgresBackupTypePgDump,
	}

	assert.NoError(t, postgresModel.Validate())
}
// enableCloud turns on cloud mode in the global config for this test and
// reverts it on cleanup. Not safe for parallel tests (global mutation).
func enableCloud(t *testing.T) {
	t.Helper()

	env := config.GetEnv()
	env.IsCloud = true
	t.Cleanup(func() {
		env.IsCloud = false
	})
}
type PostgresContainer struct {
Host string
Port int

View File

@@ -1,20 +0,0 @@
package plans
import (
"databasus-backend/internal/util/logger"
)
// Package-level singletons wiring the database-plan feature together.
var databasePlanRepository = &DatabasePlanRepository{}
var databasePlanService = &DatabasePlanService{
	databasePlanRepository,
	logger.GetLogger(),
}

// GetDatabasePlanService returns the shared DatabasePlanService singleton.
func GetDatabasePlanService() *DatabasePlanService {
	return databasePlanService
}

// GetDatabasePlanRepository returns the shared DatabasePlanRepository singleton.
func GetDatabasePlanRepository() *DatabasePlanRepository {
	return databasePlanRepository
}

View File

@@ -1,19 +0,0 @@
package plans
import (
"github.com/google/uuid"
"databasus-backend/internal/util/period"
)
// DatabasePlan holds per-database backup limits. Zero values for the size
// fields mean "unlimited" (see the self-hosted default plan created by
// DatabasePlanService).
type DatabasePlan struct {
	DatabaseID            uuid.UUID         `json:"databaseId" gorm:"column:database_id;type:uuid;primaryKey;not null"`
	MaxBackupSizeMB       int64             `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
	MaxBackupsTotalSizeMB int64             `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
	MaxStoragePeriod      period.TimePeriod `json:"maxStoragePeriod" gorm:"column:max_storage_period;type:text;not null"`
}

// TableName maps the model to the database_plans table.
func (p *DatabasePlan) TableName() string {
	return "database_plans"
}

View File

@@ -1,27 +0,0 @@
package plans
import (
	"errors"

	"github.com/google/uuid"
	"gorm.io/gorm"

	"databasus-backend/internal/storage"
)
// DatabasePlanRepository persists per-database backup plans.
type DatabasePlanRepository struct{}

// GetDatabasePlan returns the plan for the given database, or (nil, nil) when
// no plan row exists. Any other storage error is returned as-is.
func (r *DatabasePlanRepository) GetDatabasePlan(databaseID uuid.UUID) (*DatabasePlan, error) {
	var databasePlan DatabasePlan
	if err := storage.GetDb().Where("database_id = ?", databaseID).First(&databasePlan).Error; err != nil {
		// Compare against the sentinel instead of matching the error string,
		// which silently breaks if GORM ever rewords or wraps the message.
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}
	return &databasePlan, nil
}
// CreateDatabasePlan inserts a new plan row for a database.
func (r *DatabasePlanRepository) CreateDatabasePlan(databasePlan *DatabasePlan) error {
	// databasePlan is already a pointer; taking its address again made GORM
	// receive a **DatabasePlan for no benefit.
	return storage.GetDb().Create(databasePlan).Error
}

View File

@@ -1,68 +0,0 @@
package plans
import (
"log/slog"
"github.com/google/uuid"
"databasus-backend/internal/config"
"databasus-backend/internal/util/period"
)
// DatabasePlanService resolves backup limits for databases, lazily creating a
// default plan (cloud-limited or self-hosted-unlimited) when none exists.
type DatabasePlanService struct {
	databasePlanRepository *DatabasePlanRepository
	logger                 *slog.Logger
}
// GetDatabasePlan returns the stored plan for the database. When no plan
// exists yet, a default plan is built, persisted, and returned.
func (s *DatabasePlanService) GetDatabasePlan(databaseID uuid.UUID) (*DatabasePlan, error) {
	existingPlan, err := s.databasePlanRepository.GetDatabasePlan(databaseID)
	if err != nil {
		return nil, err
	}
	if existingPlan != nil {
		return existingPlan, nil
	}

	s.logger.Info("no database plan found, creating default plan", "databaseID", databaseID)
	defaultPlan := s.createDefaultDatabasePlan(databaseID)
	if createErr := s.databasePlanRepository.CreateDatabasePlan(defaultPlan); createErr != nil {
		s.logger.Error("failed to create default database plan", "error", createErr)
		return nil, createErr
	}
	return defaultPlan, nil
}
// createDefaultDatabasePlan builds the default plan for a database: tight
// limits in cloud mode, unlimited (zero values) in self-hosted mode.
func (s *DatabasePlanService) createDefaultDatabasePlan(databaseID uuid.UUID) *DatabasePlan {
	if config.GetEnv().IsCloud {
		s.logger.Info("creating default database plan for cloud", "databaseID", databaseID)
		// for playground we set limited storages enough to test,
		// but not too expensive to provide it for Databasus
		return &DatabasePlan{
			DatabaseID:            databaseID,
			MaxBackupSizeMB:       100,  // ~ 1.5GB database
			MaxBackupsTotalSizeMB: 4000, // ~ 30 daily backups + 10 manual backups
			MaxStoragePeriod:      period.PeriodWeek,
		}
	}

	s.logger.Info("creating default database plan for self hosted", "databaseID", databaseID)
	// by default - everything is unlimited in self hosted mode
	return &DatabasePlan{
		DatabaseID:            databaseID,
		MaxBackupSizeMB:       0,
		MaxBackupsTotalSizeMB: 0,
		MaxStoragePeriod:      period.PeriodForever,
	}
}

View File

@@ -775,7 +775,123 @@ func cleanupDatabaseWithBackup(database *databases.Database, backup *backups_cor
}
}
// Verifies that a multi-threaded restore request is rejected with 400 Bad
// Request when the server runs in cloud mode.
func Test_RestoreBackup_WhenCloudAndCpuCountMoreThanOne_ReturnsBadRequest(t *testing.T) {
	router := createTestRouter()

	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

	database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
	defer cleanupDatabaseWithBackup(database, backup)

	enableCloud(t)

	// CpuCount > 1 is the condition under test.
	multiThreadedTarget := &postgresql.PostgresqlDatabase{
		Version:  tools.PostgresqlVersion16,
		Host:     env_config.GetEnv().TestLocalhost,
		Port:     5432,
		Username: "postgres",
		Password: "postgres",
		CpuCount: 4,
	}
	restoreRequest := restores_core.RestoreBackupRequest{PostgresqlDatabase: multiThreadedTarget}

	response := test_utils.MakePostRequest(
		t,
		router,
		fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
		"Bearer "+owner.Token,
		restoreRequest,
		http.StatusBadRequest,
	)

	assert.Contains(t, string(response.Body), "multi-thread restore is not supported in cloud mode")
}
// Verifies that a single-threaded restore request is accepted in cloud mode.
func Test_RestoreBackup_WhenCloudAndCpuCountIsOne_RestoreInitiated(t *testing.T) {
	router := createTestRouter()

	_, cleanup := SetupMockRestoreNode(t)
	defer cleanup()

	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

	database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
	defer cleanupDatabaseWithBackup(database, backup)

	enableCloud(t)

	// CpuCount == 1 is the only thread count allowed in cloud mode.
	singleThreadedTarget := &postgresql.PostgresqlDatabase{
		Version:  tools.PostgresqlVersion16,
		Host:     env_config.GetEnv().TestLocalhost,
		Port:     5432,
		Username: "postgres",
		Password: "postgres",
		CpuCount: 1,
	}
	restoreRequest := restores_core.RestoreBackupRequest{PostgresqlDatabase: singleThreadedTarget}

	response := test_utils.MakePostRequest(
		t,
		router,
		fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
		"Bearer "+owner.Token,
		restoreRequest,
		http.StatusOK,
	)

	assert.Contains(t, string(response.Body), "restore started successfully")
}
// Verifies that multi-threaded restores remain allowed outside cloud mode
// (self-hosted deployments have no CPU-count restriction).
func Test_RestoreBackup_WhenNotCloudAndCpuCountMoreThanOne_RestoreInitiated(t *testing.T) {
	router := createTestRouter()

	_, cleanup := SetupMockRestoreNode(t)
	defer cleanup()

	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

	database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
	defer cleanupDatabaseWithBackup(database, backup)

	// Note: cloud mode is intentionally NOT enabled here.
	multiThreadedTarget := &postgresql.PostgresqlDatabase{
		Version:  tools.PostgresqlVersion16,
		Host:     env_config.GetEnv().TestLocalhost,
		Port:     5432,
		Username: "postgres",
		Password: "postgres",
		CpuCount: 4,
	}
	restoreRequest := restores_core.RestoreBackupRequest{PostgresqlDatabase: multiThreadedTarget}

	response := test_utils.MakePostRequest(
		t,
		router,
		fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
		"Bearer "+owner.Token,
		restoreRequest,
		http.StatusOK,
	)

	assert.Contains(t, string(response.Body), "restore started successfully")
}
// cleanupBackup deletes the given backup row directly through the repository.
func cleanupBackup(backup *backups_core.Backup) {
	(&backups_core.BackupRepository{}).DeleteByID(backup.ID)
}
// enableCloud switches the process-wide IsCloud flag on for the duration of
// the calling test and restores it via t.Cleanup.
//
// NOTE(review): this mutates shared global env config, so tests using this
// helper must not run in parallel — confirm no caller uses t.Parallel().
func enableCloud(t *testing.T) {
	t.Helper()
	env_config.GetEnv().IsCloud = true
	t.Cleanup(func() {
		env_config.GetEnv().IsCloud = false
	})
}

View File

@@ -129,11 +129,14 @@ func (s *RestoreService) RestoreBackupWithAuth(
return err
}
if config.GetEnv().IsCloud {
// in cloud mode we use only single thread mode,
// because otherwise we will exhaust local storage
// space (instead of streaming from S3 directly to DB)
requestDTO.PostgresqlDatabase.CpuCount = 1
if config.GetEnv().IsCloud && requestDTO.PostgresqlDatabase != nil &&
requestDTO.PostgresqlDatabase.CpuCount > 1 {
s.logger.Warn("restore rejected: multi-thread mode not supported in cloud",
"requested_cpu_count", requestDTO.PostgresqlDatabase.CpuCount)
return errors.New(
"multi-thread restore is not supported in cloud mode, only single thread (CPU=1) is allowed",
)
}
if err := s.validateVersionCompatibility(backupDatabase, requestDTO); err != nil {

View File

@@ -21,13 +21,19 @@ func NewMultiHandler(
}
// Enabled reports whether a record at the given level should be handled.
// When a VictoriaLogs writer is configured, every level from Debug upward is
// accepted (so records can still be forwarded even if stdout filters them);
// otherwise the decision is delegated to the stdout handler.
func (h *MultiHandler) Enabled(ctx context.Context, level slog.Level) bool {
	if h.victoriaLogsWriter != nil {
		return level >= slog.LevelDebug
	}
	return h.stdoutHandler.Enabled(ctx, level)
}
func (h *MultiHandler) Handle(ctx context.Context, record slog.Record) error {
// Send to stdout handler
if err := h.stdoutHandler.Handle(ctx, record); err != nil {
return err
// Send to stdout handler (only if level is enabled for stdout)
if h.stdoutHandler.Enabled(ctx, record.Level) {
if err := h.stdoutHandler.Handle(ctx, record); err != nil {
return err
}
}
// Send to VictoriaLogs if configured

View File

@@ -0,0 +1,102 @@
-- +goose Up
-- +goose StatementBegin
-- Core billing table: tracks one storage subscription per cloud database,
-- including the external payment-provider identifiers.
CREATE TABLE subscriptions (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    database_id UUID NOT NULL,
    status TEXT NOT NULL,
    storage_gb INT NOT NULL,
    -- Storage change scheduled to apply at the next billing cycle (NULL if none).
    pending_storage_gb INT,
    current_period_start TIMESTAMPTZ NOT NULL,
    current_period_end TIMESTAMPTZ NOT NULL,
    canceled_at TIMESTAMPTZ,
    -- After cancellation, backups are retained until this timestamp.
    data_retention_grace_period_until TIMESTAMPTZ,
    provider_name TEXT,
    provider_sub_id TEXT,
    provider_customer_id TEXT,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX idx_subscriptions_database_id ON subscriptions (database_id);
CREATE INDEX idx_subscriptions_status ON subscriptions (status);
CREATE INDEX idx_subscriptions_provider_sub_id ON subscriptions (provider_sub_id);

-- Invoices issued by the payment provider for a subscription period.
CREATE TABLE invoices (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    subscription_id UUID NOT NULL,
    provider_invoice_id TEXT NOT NULL,
    amount_cents BIGINT NOT NULL,
    storage_gb INT NOT NULL,
    period_start TIMESTAMPTZ NOT NULL,
    period_end TIMESTAMPTZ NOT NULL,
    status TEXT NOT NULL,
    paid_at TIMESTAMPTZ,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
ALTER TABLE invoices
    ADD CONSTRAINT fk_invoices_subscription_id
    FOREIGN KEY (subscription_id)
    REFERENCES subscriptions (id)
    ON DELETE CASCADE;
CREATE INDEX idx_invoices_subscription_id ON invoices (subscription_id);
CREATE INDEX idx_invoices_provider_invoice_id ON invoices (provider_invoice_id);

-- Audit trail of subscription lifecycle transitions (status/storage changes).
CREATE TABLE subscription_events (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    subscription_id UUID NOT NULL,
    provider_event_id TEXT,
    type TEXT NOT NULL,
    old_storage_gb INT,
    new_storage_gb INT,
    old_status TEXT,
    new_status TEXT,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
ALTER TABLE subscription_events
    ADD CONSTRAINT fk_subscription_events_subscription_id
    FOREIGN KEY (subscription_id)
    REFERENCES subscriptions (id)
    ON DELETE CASCADE;
CREATE INDEX idx_subscription_events_subscription_id ON subscription_events (subscription_id);

-- Raw provider webhook deliveries, kept for idempotency checks and debugging.
CREATE TABLE webhook_records (
    request_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    provider_name TEXT NOT NULL,
    event_type TEXT NOT NULL,
    provider_event_id TEXT NOT NULL,
    raw_payload TEXT NOT NULL,
    processed_at TIMESTAMPTZ,
    error TEXT,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX idx_webhook_records_provider_event_id ON webhook_records (provider_event_id);
-- +goose StatementEnd

-- +goose Down
-- +goose StatementBegin
-- Drop in reverse dependency order; constraints are removed defensively
-- before their tables even though DROP TABLE would cascade them.
DROP INDEX IF EXISTS idx_webhook_records_provider_event_id;
DROP TABLE IF EXISTS webhook_records;
DROP INDEX IF EXISTS idx_subscription_events_subscription_id;
ALTER TABLE subscription_events DROP CONSTRAINT IF EXISTS fk_subscription_events_subscription_id;
DROP TABLE IF EXISTS subscription_events;
DROP INDEX IF EXISTS idx_invoices_provider_invoice_id;
DROP INDEX IF EXISTS idx_invoices_subscription_id;
ALTER TABLE invoices DROP CONSTRAINT IF EXISTS fk_invoices_subscription_id;
DROP TABLE IF EXISTS invoices;
DROP INDEX IF EXISTS idx_subscriptions_provider_sub_id;
DROP INDEX IF EXISTS idx_subscriptions_status;
DROP INDEX IF EXISTS idx_subscriptions_database_id;
DROP TABLE IF EXISTS subscriptions;
-- +goose StatementEnd

View File

@@ -0,0 +1,39 @@
-- +goose Up
-- +goose StatementBegin
-- Remove the legacy per-database size limits and the database_plans table;
-- these limits are superseded by the new subscription-based billing model.
ALTER TABLE backup_configs
    DROP COLUMN IF EXISTS max_backup_size_mb,
    DROP COLUMN IF EXISTS max_backups_total_size_mb;
DROP INDEX IF EXISTS idx_database_plans_database_id;
ALTER TABLE database_plans
    DROP CONSTRAINT IF EXISTS fk_database_plans_database_id;
DROP TABLE IF EXISTS database_plans;
-- +goose StatementEnd

-- +goose Down
-- +goose StatementBegin
-- Recreate the legacy structures. Note: original column values are lost on
-- rollback; the columns come back with DEFAULT 0 (unlimited).
ALTER TABLE backup_configs
    ADD COLUMN max_backup_size_mb BIGINT NOT NULL DEFAULT 0,
    ADD COLUMN max_backups_total_size_mb BIGINT NOT NULL DEFAULT 0;
CREATE TABLE database_plans (
    database_id UUID PRIMARY KEY,
    max_backup_size_mb BIGINT NOT NULL,
    max_backups_total_size_mb BIGINT NOT NULL,
    max_storage_period TEXT NOT NULL
);
ALTER TABLE database_plans
    ADD CONSTRAINT fk_database_plans_database_id
    FOREIGN KEY (database_id)
    REFERENCES databases (id)
    ON DELETE CASCADE;
CREATE INDEX idx_database_plans_database_id ON database_plans (database_id);
-- +goose StatementEnd

View File

@@ -0,0 +1,11 @@
-- +goose Up
-- +goose StatementBegin
-- Flag marking webhook deliveries that were intentionally not processed
-- (e.g. duplicates or irrelevant event types).
ALTER TABLE webhook_records
    ADD COLUMN is_skipped BOOLEAN NOT NULL DEFAULT FALSE;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE webhook_records
    DROP COLUMN is_skipped;
-- +goose StatementEnd

View File

@@ -2,5 +2,8 @@ MODE=development
VITE_GITHUB_CLIENT_ID=
VITE_GOOGLE_CLIENT_ID=
VITE_IS_EMAIL_CONFIGURED=false
VITE_IS_CLOUD=false
VITE_CLOUDFLARE_TURNSTILE_SITE_KEY=
VITE_CLOUD_PRICE_PER_GB=
VITE_CLOUD_PADDLE_CLIENT_TOKEN=
VITE_CLOUD_IS_PADDLE_SANDBOX=true

View File

@@ -5,6 +5,7 @@ import { Routes } from 'react-router';
import { useVersionCheck } from './shared/hooks/useVersionCheck';
import { IS_CLOUD, IS_PADDLE_SANDBOX, PADDLE_CLIENT_TOKEN } from './constants';
import { userApi } from './entity/users';
import { AuthPageComponent } from './pages/AuthPageComponent';
import { OAuthCallbackPage } from './pages/OAuthCallbackPage';
@@ -18,6 +19,18 @@ function AppContent() {
useVersionCheck();
useEffect(() => {
if (IS_CLOUD && PADDLE_CLIENT_TOKEN) {
Paddle.Environment.set(IS_PADDLE_SANDBOX ? 'sandbox' : 'production');
Paddle.Initialize({
token: PADDLE_CLIENT_TOKEN,
eventCallback: (event) => {
window.dispatchEvent(new CustomEvent('paddle-event', { detail: event }));
},
});
}
}, []);
useEffect(() => {
const isAuthorized = userApi.isAuthorized();
setIsAuthorized(isAuthorized);

View File

@@ -5,6 +5,9 @@ interface RuntimeConfig {
IS_EMAIL_CONFIGURED?: string;
CLOUDFLARE_TURNSTILE_SITE_KEY?: string;
CONTAINER_ARCH?: string;
CLOUD_PRICE_PER_GB?: string;
CLOUD_PADDLE_CLIENT_TOKEN?: string;
CLOUD_IS_PADDLE_SANDBOX?: string;
}
declare global {
@@ -31,6 +34,10 @@ export const APP_VERSION = (import.meta.env.VITE_APP_VERSION as string) || 'dev'
export const IS_CLOUD =
window.__RUNTIME_CONFIG__?.IS_CLOUD === 'true' || import.meta.env.VITE_IS_CLOUD === 'true';
export const CLOUD_PRICE_PER_GB = Number(
window.__RUNTIME_CONFIG__?.CLOUD_PRICE_PER_GB || import.meta.env.VITE_CLOUD_PRICE_PER_GB || '0',
);
export const GITHUB_CLIENT_ID =
window.__RUNTIME_CONFIG__?.GITHUB_CLIENT_ID || import.meta.env.VITE_GITHUB_CLIENT_ID || '';
@@ -46,6 +53,15 @@ export const CLOUDFLARE_TURNSTILE_SITE_KEY =
import.meta.env.VITE_CLOUDFLARE_TURNSTILE_SITE_KEY ||
'';
export const PADDLE_CLIENT_TOKEN =
window.__RUNTIME_CONFIG__?.CLOUD_PADDLE_CLIENT_TOKEN ||
import.meta.env.VITE_CLOUD_PADDLE_CLIENT_TOKEN ||
'';
export const IS_PADDLE_SANDBOX =
window.__RUNTIME_CONFIG__?.CLOUD_IS_PADDLE_SANDBOX === 'true' ||
import.meta.env.VITE_CLOUD_IS_PADDLE_SANDBOX === 'true';
const archMap: Record<string, string> = { amd64: 'x64', arm64: 'arm64' };
const rawArch = window.__RUNTIME_CONFIG__?.CONTAINER_ARCH || 'unknown';
export const CONTAINER_ARCH = archMap[rawArch] || rawArch;

View File

@@ -1,7 +1,6 @@
import { getApplicationServer } from '../../../constants';
import RequestOptions from '../../../shared/api/RequestOptions';
import { apiHelper } from '../../../shared/api/apiHelper';
import type { DatabasePlan } from '../../plan';
import type { BackupConfig } from '../model/BackupConfig';
import type { TransferDatabaseRequest } from '../model/TransferDatabaseRequest';
@@ -55,12 +54,4 @@ export const backupConfigApi = {
requestOptions,
);
},
async getDatabasePlan(databaseId: string) {
return apiHelper.fetchGetJson<DatabasePlan>(
`${getApplicationServer()}/api/v1/backup-configs/database/${databaseId}/plan`,
undefined,
true,
);
},
};

View File

@@ -8,4 +8,3 @@ export { BackupEncryption } from './model/BackupEncryption';
export { PgWalBackupType } from './model/PgWalBackupType';
export { RetentionPolicyType } from './model/RetentionPolicyType';
export type { TransferDatabaseRequest } from './model/TransferDatabaseRequest';
export type { DatabasePlan } from '../plan';

View File

@@ -25,7 +25,4 @@ export interface BackupConfig {
isRetryIfFailed: boolean;
maxFailedTriesCount: number;
encryption: BackupEncryption;
maxBackupSizeMb: number;
maxBackupsTotalSizeMb: number;
}

View File

@@ -0,0 +1,66 @@
import { getApplicationServer } from '../../../constants';
import RequestOptions from '../../../shared/api/RequestOptions';
import { apiHelper } from '../../../shared/api/apiHelper';
import type { ChangeStorageResponse } from '../model/ChangeStorageResponse';
import type { GetInvoicesResponse } from '../model/GetInvoicesResponse';
import type { GetSubscriptionEventsResponse } from '../model/GetSubscriptionEventsResponse';
import type { Subscription } from '../model/Subscription';
/**
 * Client for the billing endpoints: subscription lifecycle, storage changes,
 * provider portal sessions, and paginated events/invoices.
 */
export const billingApi = {
  /** Starts a new subscription; returns the provider transaction id used to open checkout. */
  async createSubscription(databaseId: string, storageGb: number) {
    const requestOptions = new RequestOptions();
    requestOptions.setBody(JSON.stringify({ databaseId, storageGb }));

    return apiHelper.fetchPostJson<{ paddleTransactionId: string }>(
      `${getApplicationServer()}/api/v1/billing/subscription`,
      requestOptions,
    );
  },
  /** Requests a storage size change; the response says whether it applies now or next cycle. */
  async changeStorage(databaseId: string, storageGb: number) {
    const requestOptions = new RequestOptions();
    requestOptions.setBody(JSON.stringify({ databaseId, storageGb }));

    return apiHelper.fetchPostJson<ChangeStorageResponse>(
      `${getApplicationServer()}/api/v1/billing/subscription/change-storage`,
      requestOptions,
    );
  },
  /** Creates a provider-hosted management portal session and returns its URL. */
  async getPortalSession(subscriptionId: string) {
    return apiHelper.fetchPostJson<{ url: string }>(
      `${getApplicationServer()}/api/v1/billing/subscription/portal/${subscriptionId}`,
      new RequestOptions(),
    );
  },
  /** Fetches a page of subscription lifecycle events (limit/offset pagination). */
  async getSubscriptionEvents(subscriptionId: string, limit?: number, offset?: number) {
    const params = new URLSearchParams();
    if (limit !== undefined) params.append('limit', limit.toString());
    if (offset !== undefined) params.append('offset', offset.toString());

    const query = params.toString();
    const url = `${getApplicationServer()}/api/v1/billing/subscription/events/${subscriptionId}${query ? `?${query}` : ''}`;

    return apiHelper.fetchGetJson<GetSubscriptionEventsResponse>(url, undefined, true);
  },
  /** Fetches a page of invoices for a subscription (limit/offset pagination). */
  async getInvoices(subscriptionId: string, limit?: number, offset?: number) {
    const params = new URLSearchParams();
    if (limit !== undefined) params.append('limit', limit.toString());
    if (offset !== undefined) params.append('offset', offset.toString());

    const query = params.toString();
    const url = `${getApplicationServer()}/api/v1/billing/subscription/invoices/${subscriptionId}${query ? `?${query}` : ''}`;

    return apiHelper.fetchGetJson<GetInvoicesResponse>(url, undefined, true);
  },
  /** Returns the subscription attached to the given database. */
  async getSubscription(databaseId: string) {
    return apiHelper.fetchGetJson<Subscription>(
      `${getApplicationServer()}/api/v1/billing/subscription/${databaseId}`,
      undefined,
      true,
    );
  },
};

View File

@@ -0,0 +1,11 @@
export { billingApi } from './api/billingApi';
export { SubscriptionStatus } from './model/SubscriptionStatus';
export type { Subscription } from './model/Subscription';
export { InvoiceStatus } from './model/InvoiceStatus';
export type { Invoice } from './model/Invoice';
export { SubscriptionEventType } from './model/SubscriptionEventType';
export type { SubscriptionEvent } from './model/SubscriptionEvent';
export { ChangeStorageApplyMode } from './model/ChangeStorageApplyMode';
export type { ChangeStorageResponse } from './model/ChangeStorageResponse';
export type { GetSubscriptionEventsResponse } from './model/GetSubscriptionEventsResponse';
export type { GetInvoicesResponse } from './model/GetInvoicesResponse';

View File

@@ -0,0 +1,4 @@
/** When a requested storage change takes effect. */
export enum ChangeStorageApplyMode {
  /** Applied right away within the current billing period. */
  Immediate = 'immediate',
  /** Deferred until the next billing cycle starts. */
  NextCycle = 'next_cycle',
}

View File

@@ -0,0 +1,7 @@
import { ChangeStorageApplyMode } from './ChangeStorageApplyMode';
/** Server response to a storage-change request. */
export interface ChangeStorageResponse {
  // Whether the change applied immediately or is scheduled for next cycle.
  applyMode: ChangeStorageApplyMode;
  // Storage currently in effect, in GB.
  currentGb: number;
  // Scheduled storage for the next cycle, if the change was deferred.
  pendingGb?: number;
}

View File

@@ -0,0 +1,8 @@
import type { Invoice } from './Invoice';
/** Paginated list of invoices for a subscription. */
export interface GetInvoicesResponse {
  invoices: Invoice[];
  // Total number of invoices matching the query (not just this page).
  total: number;
  limit: number;
  offset: number;
}

View File

@@ -0,0 +1,8 @@
import type { SubscriptionEvent } from './SubscriptionEvent';
/** Paginated list of subscription lifecycle events. */
export interface GetSubscriptionEventsResponse {
  events: SubscriptionEvent[];
  // Total number of events matching the query (not just this page).
  total: number;
  limit: number;
  offset: number;
}

View File

@@ -0,0 +1,14 @@
import { InvoiceStatus } from './InvoiceStatus';
/** Invoice issued by the payment provider for one subscription period. */
export interface Invoice {
  id: string;
  subscriptionId: string;
  // Identifier of this invoice on the payment provider's side.
  providerInvoiceId: string;
  // Amount in the smallest currency unit (cents).
  amountCents: number;
  // Storage tier billed for this period, in GB.
  storageGb: number;
  // Billing period covered by this invoice (ISO timestamps).
  periodStart: string;
  periodEnd: string;
  status: InvoiceStatus;
  paidAt?: string;
  createdAt: string;
}

View File

@@ -0,0 +1,7 @@
/** Payment state of an invoice. */
export enum InvoiceStatus {
  Pending = 'pending',
  Paid = 'paid',
  Failed = 'failed',
  Refunded = 'refunded',
  Disputed = 'disputed',
}

View File

@@ -0,0 +1,18 @@
import { SubscriptionStatus } from './SubscriptionStatus';
/** Storage subscription attached to a single database. */
export interface Subscription {
  id: string;
  databaseId: string;
  status: SubscriptionStatus;
  // Storage currently purchased, in GB.
  storageGb: number;
  // Storage scheduled to take effect next billing cycle, if any.
  pendingStorageGb?: number;
  // Current billing period bounds (ISO timestamps).
  currentPeriodStart: string;
  currentPeriodEnd: string;
  canceledAt?: string;
  // After cancellation, backups are retained until this timestamp.
  dataRetentionGracePeriodUntil?: string;
  // Payment provider linkage (absent for e.g. trial subscriptions).
  providerName?: string;
  providerSubId?: string;
  providerCustomerId?: string;
  createdAt: string;
  updatedAt: string;
}

View File

@@ -0,0 +1,14 @@
import { SubscriptionEventType } from './SubscriptionEventType';
import { SubscriptionStatus } from './SubscriptionStatus';
/** One entry in a subscription's audit trail of status/storage transitions. */
export interface SubscriptionEvent {
  id: string;
  subscriptionId: string;
  // Provider-side event id, when the event originated from a webhook.
  providerEventId?: string;
  type: SubscriptionEventType;
  // Before/after values; only the fields relevant to the event type are set.
  oldStorageGb?: number;
  newStorageGb?: number;
  oldStatus?: SubscriptionStatus;
  newStatus?: SubscriptionStatus;
  createdAt: string;
}

View File

@@ -0,0 +1,13 @@
/** Kinds of subscription/payment lifecycle events recorded in the audit trail. */
export enum SubscriptionEventType {
  Created = 'subscription.created',
  Upgraded = 'subscription.upgraded',
  Downgraded = 'subscription.downgraded',
  NewBillingCycleStarted = 'subscription.new_billing_cycle_started',
  Canceled = 'subscription.canceled',
  Reactivated = 'subscription.reactivated',
  Expired = 'subscription.expired',
  PastDue = 'subscription.past_due',
  RecoveredFromPastDue = 'subscription.recovered_from_past_due',
  // Payment-level events (not subscription state transitions).
  Refund = 'payment.refund',
  Dispute = 'payment.dispute',
}

View File

@@ -0,0 +1,7 @@
/** Lifecycle state of a subscription. */
export enum SubscriptionStatus {
  Trial = 'trial',
  Active = 'active',
  PastDue = 'past_due',
  Canceled = 'canceled',
  Expired = 'expired',
}

View File

@@ -1 +0,0 @@
export type { DatabasePlan } from './model/DatabasePlan';

View File

@@ -1,8 +0,0 @@
import type { Period } from '../../databases/model/Period';
/** Per-database backup limits; zero values mean "unlimited". */
export interface DatabasePlan {
  databaseId: string;
  // Max size of a single backup in MB (0 = unlimited).
  maxBackupSizeMb: number;
  // Max combined size of all stored backups in MB (0 = unlimited).
  maxBackupsTotalSizeMb: number;
  maxStoragePeriod: Period;
}

View File

@@ -0,0 +1,131 @@
import { Button } from 'antd';
import dayjs from 'dayjs';
import { useEffect, useState } from 'react';
import { type Subscription, SubscriptionStatus, billingApi } from '../../../entity/billing';
import { getUserShortTimeFormat } from '../../../shared/time';
import { PurchaseComponent } from '../../billing';
/** Props for BackupsBillingBannerComponent. */
interface Props {
  // Database whose subscription state drives the banner.
  databaseId: string;
  // Whether the current user may manage databases (controls action buttons).
  isCanManageDBs: boolean;
  // Invoked when "Go to Billing" is clicked on a canceled subscription.
  onNavigateToBilling?: () => void;
}
/**
 * Banner shown above the backups list when the database's subscription is in
 * the trial, canceled, or expired state. Renders nothing for active/past-due
 * subscriptions or when the subscription cannot be loaded.
 *
 * NOTE(review): dayjs `.utc()` and `.fromNow()` require the utc and
 * relativeTime plugins to be registered elsewhere — confirm app setup.
 */
export const BackupsBillingBannerComponent = ({
  databaseId,
  isCanManageDBs,
  onNavigateToBilling,
}: Props) => {
  const [subscription, setSubscription] = useState<Subscription | null>(null);
  const [isPurchaseModalOpen, setIsPurchaseModalOpen] = useState(false);

  // A failed fetch is treated as "no subscription": the banner just hides.
  const loadSubscription = async () => {
    try {
      const sub = await billingApi.getSubscription(databaseId);
      setSubscription(sub);
    } catch {
      setSubscription(null);
    }
  };

  useEffect(() => {
    loadSubscription();
  }, [databaseId]);

  // Only trial/canceled/expired states produce a banner.
  if (
    !subscription ||
    (subscription.status !== SubscriptionStatus.Trial &&
      subscription.status !== SubscriptionStatus.Canceled &&
      subscription.status !== SubscriptionStatus.Expired)
  ) {
    return null;
  }

  return (
    <>
      <div
        className={`mt-3 rounded-lg px-4 py-3 text-sm ${
          subscription.status === SubscriptionStatus.Canceled ||
          subscription.status === SubscriptionStatus.Expired
            ? 'border border-red-600/30 bg-red-900/20'
            : 'border border-yellow-600/30 bg-yellow-900/20'
        }`}
      >
        <p
          className={
            subscription.status === SubscriptionStatus.Canceled ||
            subscription.status === SubscriptionStatus.Expired
              ? 'text-red-400'
              : 'text-yellow-400'
          }
        >
          {subscription.status === SubscriptionStatus.Trial && (
            <>
              You are on a free trial. Your trial ends on{' '}
              <span className="font-medium">
                {dayjs
                  .utc(subscription.currentPeriodEnd)
                  .local()
                  .format(getUserShortTimeFormat().format)}
              </span>{' '}
              ({dayjs.utc(subscription.currentPeriodEnd).local().fromNow()}). After that, backups
              will be removed.
            </>
          )}
          {subscription.status === SubscriptionStatus.Canceled && (
            <>
              Your subscription has been canceled.{' '}
              {subscription.dataRetentionGracePeriodUntil ? (
                <>
                  Backups will be removed on{' '}
                  <span className="font-medium">
                    {dayjs
                      .utc(subscription.dataRetentionGracePeriodUntil)
                      .local()
                      .format(getUserShortTimeFormat().format)}
                  </span>{' '}
                  ({dayjs.utc(subscription.dataRetentionGracePeriodUntil).local().fromNow()}).
                </>
              ) : (
                <> Backups will be removed after the grace period.</>
              )}
            </>
          )}
          {subscription.status === SubscriptionStatus.Expired && (
            <>Your subscription has expired.</>
          )}
        </p>
        {isCanManageDBs &&
          subscription.status === SubscriptionStatus.Canceled &&
          onNavigateToBilling && (
            <Button type="primary" size="small" className="mt-2" onClick={onNavigateToBilling}>
              Go to Billing
            </Button>
          )}
        {isCanManageDBs && subscription.status !== SubscriptionStatus.Canceled && (
          <Button
            type="primary"
            size="small"
            className="mt-2"
            onClick={() => setIsPurchaseModalOpen(true)}
          >
            Purchase storage
          </Button>
        )}
      </div>
      {isPurchaseModalOpen && (
        <PurchaseComponent
          databaseId={databaseId}
          onSubscriptionChanged={() => loadSubscription()}
          onClose={() => setIsPurchaseModalOpen(false)}
        />
      )}
    </>
  );
};

View File

@@ -14,6 +14,7 @@ import type { ColumnsType } from 'antd/es/table';
import dayjs from 'dayjs';
import { useEffect, useRef, useState } from 'react';
import { IS_CLOUD } from '../../../constants';
import {
type Backup,
type BackupConfig,
@@ -28,6 +29,7 @@ import { getUserTimeFormat } from '../../../shared/time';
import { ConfirmationComponent } from '../../../shared/ui';
import { RestoresComponent } from '../../restores';
import { AgentRestoreComponent } from './AgentRestoreComponent';
import { BackupsBillingBannerComponent } from './BackupsBillingBannerComponent';
const BACKUPS_PAGE_SIZE = 50;
@@ -36,6 +38,7 @@ interface Props {
isCanManageDBs: boolean;
isDirectlyUnderTab?: boolean;
scrollContainerRef?: React.RefObject<HTMLDivElement | null>;
onNavigateToBilling?: () => void;
}
export const BackupsComponent = ({
@@ -43,6 +46,7 @@ export const BackupsComponent = ({
isCanManageDBs,
isDirectlyUnderTab,
scrollContainerRef,
onNavigateToBilling,
}: Props) => {
const [isBackupsLoading, setIsBackupsLoading] = useState(false);
const [backups, setBackups] = useState<Backup[]>([]);
@@ -510,6 +514,14 @@ export const BackupsComponent = ({
>
<h2 className="text-lg font-bold md:text-xl dark:text-white">Backups</h2>
{IS_CLOUD && (
<BackupsBillingBannerComponent
databaseId={database.id}
isCanManageDBs={isCanManageDBs}
onNavigateToBilling={onNavigateToBilling}
/>
)}
{!isBackupConfigLoading && !backupConfig?.isBackupsEnabled && (
<div className="text-sm text-red-600">
Scheduled backups are disabled (you can enable it back in the backup configuration)

View File

@@ -19,7 +19,6 @@ import { IS_CLOUD } from '../../../constants';
import {
type BackupConfig,
BackupEncryption,
type DatabasePlan,
RetentionPolicyType,
backupConfigApi,
} from '../../../entity/backups';
@@ -97,13 +96,9 @@ export const EditBackupConfigComponent = ({
const [isShowWarn, setIsShowWarn] = useState(false);
const [databasePlan, setDatabasePlan] = useState<DatabasePlan>();
const [isLoading, setIsLoading] = useState(true);
const hasAdvancedValues =
!!backupConfig?.isRetryIfFailed ||
(backupConfig?.maxBackupSizeMb ?? 0) > 0 ||
(backupConfig?.maxBackupsTotalSizeMb ?? 0) > 0;
const hasAdvancedValues = !!backupConfig?.isRetryIfFailed;
const [isShowAdvanced, setShowAdvanced] = useState(hasAdvancedValues);
const [isShowGfsHint, setShowGfsHint] = useState(false);
@@ -114,65 +109,6 @@ export const EditBackupConfigComponent = ({
const dateTimeFormat = useMemo(() => getUserTimeFormat(), []);
const createDefaultPlan = (databaseId: string, isCloud: boolean): DatabasePlan => {
if (isCloud) {
return {
databaseId,
maxBackupSizeMb: 100,
maxBackupsTotalSizeMb: 4000,
maxStoragePeriod: Period.WEEK,
};
} else {
return {
databaseId,
maxBackupSizeMb: 0,
maxBackupsTotalSizeMb: 0,
maxStoragePeriod: Period.FOREVER,
};
}
};
const isPeriodAllowed = (period: Period, maxPeriod: Period): boolean => {
const periodOrder = [
Period.DAY,
Period.WEEK,
Period.MONTH,
Period.THREE_MONTH,
Period.SIX_MONTH,
Period.YEAR,
Period.TWO_YEARS,
Period.THREE_YEARS,
Period.FOUR_YEARS,
Period.FIVE_YEARS,
Period.FOREVER,
];
const periodIndex = periodOrder.indexOf(period);
const maxIndex = periodOrder.indexOf(maxPeriod);
return periodIndex <= maxIndex;
};
const availablePeriods = useMemo(() => {
const allPeriods = [
{ label: '1 day', value: Period.DAY },
{ label: '1 week', value: Period.WEEK },
{ label: '1 month', value: Period.MONTH },
{ label: '3 months', value: Period.THREE_MONTH },
{ label: '6 months', value: Period.SIX_MONTH },
{ label: '1 year', value: Period.YEAR },
{ label: '2 years', value: Period.TWO_YEARS },
{ label: '3 years', value: Period.THREE_YEARS },
{ label: '4 years', value: Period.FOUR_YEARS },
{ label: '5 years', value: Period.FIVE_YEARS },
{ label: 'Forever', value: Period.FOREVER },
];
if (!databasePlan) {
return allPeriods;
}
return allPeriods.filter((p) => isPeriodAllowed(p.value, databasePlan.maxStoragePeriod));
}, [databasePlan]);
const updateBackupConfig = (patch: Partial<BackupConfig>) => {
setBackupConfig((prev) => (prev ? { ...prev, ...patch } : prev));
setIsUnsaved(true);
@@ -237,13 +173,7 @@ export const EditBackupConfigComponent = ({
setBackupConfig(config);
setIsUnsaved(false);
setIsSaving(false);
const plan = await backupConfigApi.getDatabasePlan(database.id);
setDatabasePlan(plan);
} else {
const plan = createDefaultPlan('', IS_CLOUD);
setDatabasePlan(plan);
setBackupConfig({
databaseId: database.id,
isBackupsEnabled: true,
@@ -256,11 +186,7 @@ export const EditBackupConfigComponent = ({
retentionPolicyType: IS_CLOUD
? RetentionPolicyType.GFS
: RetentionPolicyType.TimePeriod,
retentionTimePeriod: IS_CLOUD
? plan.maxStoragePeriod === Period.FOREVER
? Period.THREE_MONTH
: plan.maxStoragePeriod
: Period.THREE_MONTH,
retentionTimePeriod: Period.THREE_MONTH,
retentionCount: 100,
retentionGfsHours: 24,
retentionGfsDays: 7,
@@ -271,8 +197,6 @@ export const EditBackupConfigComponent = ({
isRetryIfFailed: true,
maxFailedTriesCount: 3,
encryption: BackupEncryption.ENCRYPTED,
maxBackupSizeMb: plan.maxBackupSizeMb,
maxBackupsTotalSizeMb: plan.maxBackupsTotalSizeMb,
});
}
@@ -604,7 +528,19 @@ export const EditBackupConfigComponent = ({
onChange={(v) => updateBackupConfig({ retentionTimePeriod: v })}
size="small"
className="w-[200px]"
options={availablePeriods}
options={[
{ label: '1 day', value: Period.DAY },
{ label: '1 week', value: Period.WEEK },
{ label: '1 month', value: Period.MONTH },
{ label: '3 months', value: Period.THREE_MONTH },
{ label: '6 months', value: Period.SIX_MONTH },
{ label: '1 year', value: Period.YEAR },
{ label: '2 years', value: Period.TWO_YEARS },
{ label: '3 years', value: Period.THREE_YEARS },
{ label: '4 years', value: Period.FOUR_YEARS },
{ label: '5 years', value: Period.FIVE_YEARS },
{ label: 'Forever', value: Period.FOREVER },
]}
/>
<Tooltip
@@ -829,121 +765,6 @@ export const EditBackupConfigComponent = ({
</div>
</div>
)}
<div className="mt-5 mb-1 flex w-full flex-col items-start sm:flex-row sm:items-center">
<div className="mb-1 min-w-[150px] sm:mb-0">Max backup size limit</div>
<div className="flex items-center">
<Switch
size="small"
checked={backupConfig.maxBackupSizeMb > 0}
disabled={IS_CLOUD}
onChange={(checked) => {
updateBackupConfig({
maxBackupSizeMb: checked ? backupConfig.maxBackupSizeMb || 1000 : 0,
});
}}
/>
<Tooltip
className="cursor-pointer"
title="Limits the size of each individual backup. Note that backups are typically 15× smaller than the database size. For example, a 100 MB backup represents approximately 1.5 GB database."
>
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
</Tooltip>
</div>
</div>
{backupConfig.maxBackupSizeMb > 0 && (
<div className="mb-5 flex w-full flex-col items-start sm:flex-row sm:items-center">
<div className="mb-1 min-w-[150px] sm:mb-0">Max file size (MB)</div>
<InputNumber
min={1}
max={
databasePlan?.maxBackupSizeMb && databasePlan.maxBackupSizeMb > 0
? databasePlan.maxBackupSizeMb
: undefined
}
value={backupConfig.maxBackupSizeMb}
onChange={(value) => {
const newValue = value || 1;
if (databasePlan?.maxBackupSizeMb && databasePlan.maxBackupSizeMb > 0) {
updateBackupConfig({
maxBackupSizeMb: Math.min(newValue, databasePlan.maxBackupSizeMb),
});
} else {
updateBackupConfig({ maxBackupSizeMb: newValue });
}
}}
size="small"
className="w-full max-w-[75px] grow"
/>
<div className="ml-2 text-xs text-gray-600 dark:text-gray-400">
~{((backupConfig.maxBackupSizeMb / 1024) * 15).toFixed(2)} GB DB size
</div>
</div>
)}
<div className="mb-1 flex w-full flex-col items-start sm:flex-row sm:items-center">
<div className="mb-1 min-w-[150px] sm:mb-0">Limit total backups size</div>
<div className="flex items-center">
<Switch
size="small"
checked={backupConfig.maxBackupsTotalSizeMb > 0}
disabled={IS_CLOUD}
onChange={(checked) => {
updateBackupConfig({
maxBackupsTotalSizeMb: checked
? backupConfig.maxBackupsTotalSizeMb || 1_000_000
: 0,
});
}}
/>
<Tooltip
className="cursor-pointer"
title="Limits the total size of all backups in storage (like S3, local, etc.). Once this limit is exceeded, the oldest backups are automatically removed until the total size is within the limit again."
>
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
</Tooltip>
</div>
</div>
{backupConfig.maxBackupsTotalSizeMb > 0 && (
<div className="mb-1 flex w-full flex-col items-start sm:flex-row sm:items-center">
<div className="mb-1 min-w-[150px] sm:mb-0">Backups files size (MB)</div>
<InputNumber
min={1}
max={
databasePlan?.maxBackupsTotalSizeMb && databasePlan.maxBackupsTotalSizeMb > 0
? databasePlan.maxBackupsTotalSizeMb
: undefined
}
value={backupConfig.maxBackupsTotalSizeMb}
onChange={(value) => {
const newValue = value || 1;
if (
databasePlan?.maxBackupsTotalSizeMb &&
databasePlan.maxBackupsTotalSizeMb > 0
) {
updateBackupConfig({
maxBackupsTotalSizeMb: Math.min(newValue, databasePlan.maxBackupsTotalSizeMb),
});
} else {
updateBackupConfig({ maxBackupsTotalSizeMb: newValue });
}
}}
size="small"
className="w-full max-w-[75px] grow"
/>
<div className="ml-2 text-xs text-gray-600 dark:text-gray-400">
{(backupConfig.maxBackupsTotalSizeMb / 1024).toFixed(2)} GB (~
{backupConfig.maxBackupsTotalSizeMb / backupConfig.maxBackupSizeMb} backups)
</div>
</div>
)}
</>
)}

View File

@@ -0,0 +1,199 @@
import { App } from 'antd';
import { useEffect, useRef, useState } from 'react';
import { ChangeStorageApplyMode, SubscriptionStatus, billingApi } from '../../../entity/billing';
import type { Subscription } from '../../../entity/billing';
import { POLL_INTERVAL_MS, POLL_TIMEOUT_MS, findSliderPosForGb } from '../models/purchaseUtils';
interface UsePurchaseFlowParams {
  databaseId: string;
  onSubscriptionChanged: () => void;
  onClose: () => void;
}

/**
 * Orchestrates the subscription purchase / storage-change flow for one database:
 * loads the current subscription, opens the Paddle checkout, and polls the
 * backend until the payment or the storage upgrade is confirmed (or times out).
 *
 * NOTE(review): relies on a global `Paddle` object and `paddle-event`
 * CustomEvents dispatched on `window` — presumably set up by the Paddle.js
 * loader elsewhere in the app; confirm before reuse.
 */
export function usePurchaseFlow({
  databaseId,
  onSubscriptionChanged,
  onClose,
}: UsePurchaseFlowParams) {
  const { message } = App.useApp();
  const [subscription, setSubscription] = useState<Subscription | null>(null);
  const [isLoadingSubscription, setIsLoadingSubscription] = useState(true);
  const [loadError, setLoadError] = useState<string | null>(null);
  const [isSubmitting, setIsSubmitting] = useState(false);
  const [isCheckoutOpen, setIsCheckoutOpen] = useState(false);
  const [isWaitingForPayment, setIsWaitingForPayment] = useState(false);
  const [isPaymentConfirmed, setIsPaymentConfirmed] = useState(false);
  const [confirmedStorageGb, setConfirmedStorageGb] = useState<number | undefined>();
  const [isPaymentTimedOut, setIsPaymentTimedOut] = useState(false);
  const [isWaitingForUpgrade, setIsWaitingForUpgrade] = useState(false);
  const [isUpgradeTimedOut, setIsUpgradeTimedOut] = useState(false);
  const [initialSliderPos, setInitialSliderPos] = useState(0);

  // At most one poll runs at a time: an interval doing the work plus a timeout
  // acting as its watchdog. Both are cleared together via stopPolling().
  const pollingRef = useRef<ReturnType<typeof setInterval> | null>(null);
  const timeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);

  // Clears the polling interval and its timeout watchdog, if either is active.
  const stopPolling = () => {
    if (pollingRef.current) {
      clearInterval(pollingRef.current);
      pollingRef.current = null;
    }
    if (timeoutRef.current) {
      clearTimeout(timeoutRef.current);
      timeoutRef.current = null;
    }
  };

  // Fetches the subscription and positions the storage slider at its current size.
  const loadSubscription = async () => {
    setIsLoadingSubscription(true);
    setLoadError(null);
    try {
      const sub = await billingApi.getSubscription(databaseId);
      setSubscription(sub);
      setInitialSliderPos(findSliderPosForGb(sub.storageGb));
    } catch {
      setLoadError('Failed to load subscription');
    } finally {
      setIsLoadingSubscription(false);
    }
  };

  // After checkout completes, polls until the backend reports the subscription
  // has left the Trial/Expired/Canceled states (i.e. payment was processed).
  const pollForPaymentConfirmation = () => {
    setIsWaitingForPayment(true);
    setIsPaymentTimedOut(false);
    pollingRef.current = setInterval(async () => {
      try {
        const sub = await billingApi.getSubscription(databaseId);
        if (
          sub.status !== SubscriptionStatus.Trial &&
          sub.status !== SubscriptionStatus.Expired &&
          sub.status !== SubscriptionStatus.Canceled
        ) {
          stopPolling();
          setIsWaitingForPayment(false);
          setIsPaymentConfirmed(true);
          setConfirmedStorageGb(sub.storageGb);
          onSubscriptionChanged();
        }
      } catch {
        // ignore polling errors, keep trying
      }
    }, POLL_INTERVAL_MS);
    timeoutRef.current = setTimeout(() => {
      stopPolling();
      setIsWaitingForPayment(false);
      setIsPaymentTimedOut(true);
    }, POLL_TIMEOUT_MS);
  };

  // After an immediate storage change, polls until the backend reflects the new
  // size with no pending change left.
  // NOTE(review): this checks `pendingStorageGb === undefined` while other
  // components use `!= null` — confirm the API never returns null here.
  const pollForUpgradeConfirmation = (targetStorageGb: number) => {
    setIsWaitingForUpgrade(true);
    setIsUpgradeTimedOut(false);
    pollingRef.current = setInterval(async () => {
      try {
        const sub = await billingApi.getSubscription(databaseId);
        if (sub.storageGb === targetStorageGb && sub.pendingStorageGb === undefined) {
          stopPolling();
          setIsWaitingForUpgrade(false);
          onSubscriptionChanged();
          onClose();
        }
      } catch {
        // ignore polling errors, keep trying
      }
    }, POLL_INTERVAL_MS);
    timeoutRef.current = setTimeout(() => {
      stopPolling();
      setIsWaitingForUpgrade(false);
      setIsUpgradeTimedOut(true);
    }, POLL_TIMEOUT_MS);
  };

  // Creates a subscription on the backend and opens the Paddle checkout overlay
  // for the returned transaction.
  const handlePurchase = async (storageGb: number) => {
    setIsSubmitting(true);
    try {
      const result = await billingApi.createSubscription(databaseId, storageGb);
      setIsCheckoutOpen(true);
      setIsSubmitting(false);
      Paddle.Checkout.open({
        transactionId: result.paddleTransactionId,
      });
    } catch {
      message.error('Failed to create subscription');
      setIsSubmitting(false);
    }
  };

  // Requests a storage change; immediate upgrades are polled for confirmation,
  // deferred changes (next billing cycle) just close the dialog.
  const handleStorageChange = async (storageGb: number) => {
    if (!subscription) return;
    setIsSubmitting(true);
    try {
      const result = await billingApi.changeStorage(databaseId, storageGb);
      if (result.applyMode === ChangeStorageApplyMode.Immediate) {
        setIsSubmitting(false);
        pollForUpgradeConfirmation(storageGb);
      } else {
        setIsSubmitting(false);
        onSubscriptionChanged();
        onClose();
      }
    } catch {
      message.error('Failed to change storage');
      setIsSubmitting(false);
    }
  };

  // Initial load; any in-flight polling is torn down on unmount.
  useEffect(() => {
    loadSubscription();
    return () => stopPolling();
  }, [databaseId]);

  // Bridges Paddle overlay events (re-dispatched on window) into this flow:
  // a completed checkout starts payment polling, a closed one just resets state.
  useEffect(() => {
    const handlePaddleEvent = (e: Event) => {
      const event = (e as CustomEvent<PaddleEvent>).detail;
      if (event.name === 'checkout.completed') {
        setIsCheckoutOpen(false);
        pollForPaymentConfirmation();
      } else if (event.name === 'checkout.closed') {
        setIsCheckoutOpen(false);
      }
    };
    window.addEventListener('paddle-event', handlePaddleEvent);
    return () => window.removeEventListener('paddle-event', handlePaddleEvent);
  }, [databaseId]);

  return {
    subscription,
    isLoadingSubscription,
    loadError,
    isSubmitting,
    isCheckoutOpen,
    isWaitingForPayment,
    isPaymentConfirmed,
    confirmedStorageGb,
    isPaymentTimedOut,
    isWaitingForUpgrade,
    isUpgradeTimedOut,
    initialSliderPos,
    handlePurchase,
    handleStorageChange,
  };
}

View File

@@ -0,0 +1,2 @@
export { BillingComponent } from './ui/BillingComponent';
export { PurchaseComponent } from './ui/PurchaseComponent';

View File

@@ -0,0 +1,82 @@
// Used to estimate the original database size from a single (compressed) backup
// size: approximateDbSize = backupSizeGb * BACKUPS_COMPRESSION_RATIO.
const BACKUPS_COMPRESSION_RATIO = 10;
/**
 * Builds the selectable single-backup sizes (GB): every 1 GB from 1 to 100,
 * then 10 GB increments from 110 to 200.
 */
function buildBackupSizeSteps(): number[] {
  const fine = Array.from({ length: 100 }, (_, index) => index + 1);
  const coarse = Array.from({ length: 10 }, (_, index) => 110 + index * 10);
  return [...fine, ...coarse];
}
/**
 * Builds the selectable total-storage sizes (GB) with progressively coarser
 * steps: 1 GB steps up to 100, 10 GB up to 1000, 100 GB up to 5000, and
 * 1000 GB up to 10000.
 */
function buildStorageSizeSteps(): number[] {
  const ranges: Array<[number, number, number]> = [
    [20, 100, 1],
    [110, 1000, 10],
    [1100, 5000, 100],
    [6000, 10000, 1000],
  ];
  const values: number[] = [];
  for (const [start, end, step] of ranges) {
    for (let value = start; value <= end; value += step) {
      values.push(value);
    }
  }
  return values;
}
// Precomputed slider step tables (slider position -> GB value).
const BACKUP_SIZE_STEPS = buildBackupSizeSteps();
const STORAGE_SIZE_STEPS = buildStorageSizeSteps();

// Copy-paste snippets shown to the user for measuring their current DB size.
const DB_SIZE_COMMANDS = [
  {
    label: 'PostgreSQL',
    code: `SELECT pg_size_pretty(pg_database_size(current_database()));`,
  },
  {
    label: 'MySQL / MariaDB',
    code: `SELECT table_schema AS 'Database',
ROUND(SUM(data_length + index_length) / 1024 / 1024, 2) AS 'Size (MB)'
FROM information_schema.tables
GROUP BY table_schema;`,
  },
  {
    label: 'MongoDB',
    code: `db.stats(1024 * 1024) // size in MB`,
  },
];

// How often, and for how long overall, to poll the backend while waiting for a
// payment or storage-upgrade confirmation.
const POLL_INTERVAL_MS = 3000;
const POLL_TIMEOUT_MS = 2 * 60 * 1000;
/**
 * Splits a total backup count into a suggested GFS (Grandfather-Father-Son)
 * retention: up to 7 daily, then up to 4 weekly, then up to 12 monthly,
 * with any remainder kept as yearly backups.
 */
function distributeGfs(total: number) {
  const daily = Math.min(7, total);
  let remaining = Math.max(0, total - daily);
  const weekly = Math.min(4, remaining);
  remaining = Math.max(0, remaining - weekly);
  const monthly = Math.min(12, remaining);
  const yearly = Math.max(0, remaining - monthly);
  return { daily, weekly, monthly, yearly };
}
/**
 * Formats a size given in GB for display, switching to TB at 1000 GB.
 * Whole terabytes are shown without decimals ("2 TB"), fractional ones with
 * one decimal place ("1.5 TB").
 */
function formatSize(gb: number): string {
  if (gb < 1000) {
    return `${gb} GB`;
  }
  const tb = gb / 1000;
  const label = Number.isInteger(tb) ? `${tb}` : tb.toFixed(1);
  return `${label} TB`;
}
/**
 * Computes the inline background for a range slider track: filled (blue) up to
 * the current position, unfilled (dark gray) after it.
 */
function sliderBackground(pos: number, max: number): React.CSSProperties {
  const filledPercent = (pos / max) * 100;
  const gradient = `linear-gradient(to right, #155dfc ${filledPercent}%, #1f2937 ${filledPercent}%)`;
  return { background: gradient };
}
/**
 * Maps a storage size in GB to the first slider position whose step is at
 * least that size; sizes beyond the largest step map to the last position.
 */
function findSliderPosForGb(gb: number): number {
  for (let pos = 0; pos < STORAGE_SIZE_STEPS.length; pos++) {
    if (STORAGE_SIZE_STEPS[pos] >= gb) {
      return pos;
    }
  }
  return STORAGE_SIZE_STEPS.length - 1;
}
export {
BACKUPS_COMPRESSION_RATIO,
BACKUP_SIZE_STEPS,
STORAGE_SIZE_STEPS,
DB_SIZE_COMMANDS,
POLL_INTERVAL_MS,
POLL_TIMEOUT_MS,
distributeGfs,
formatSize,
sliderBackground,
findSliderPosForGb,
};

View File

@@ -0,0 +1,59 @@
interface Gfs {
  daily: number;
  weekly: number;
  monthly: number;
  yearly: number;
}

interface Props {
  // Total number of backups that fit into the selected storage.
  backupsFit: number;
  // Suggested GFS (Grandfather-Father-Son) retention split of that total.
  gfs: Gfs;
}

/**
 * Shows how many backups fit into the selected storage, both as a flat total
 * and as a suggested GFS retention split (daily/weekly/monthly/yearly).
 */
export function BackupRetentionSection({ backupsFit, gfs }: Props) {
  return (
    <div>
      <div className="space-y-1.5">
        <div className="rounded-lg border border-[#ffffff20] bg-[#1f2937]/50 px-3 py-2 text-center">
          <p className="text-gray-500">Total backups</p>
          <p className="text-lg font-bold text-gray-200">{backupsFit}</p>
        </div>
        <div className="my-1 flex items-center gap-3">
          <div className="h-px flex-1 bg-[#ffffff20]" />
          <span className="text-sm text-gray-500">or</span>
          <div className="h-px flex-1 bg-[#ffffff20]" />
        </div>
        {/* Fixed garbled user-facing copy ("less often broad time at the lowest cost"). */}
        <p className="mb-2 text-sm text-gray-400">
          Keeps recent backups frequently and older ones less often, covering a broad time range
          at the lowest cost. It is enough to keep the following amount of backups in GFS:
        </p>
        <div className="grid grid-cols-2 gap-1.5">
          {(
            [
              ['Daily', gfs.daily],
              ['Weekly', gfs.weekly],
              ['Monthly', gfs.monthly],
              ['Yearly', gfs.yearly],
            ] as const
          ).map(([label, value]) => (
            <div
              key={label}
              className="rounded-lg border border-[#ffffff20] bg-[#1f2937]/50 px-2 py-1.5 text-center"
            >
              <p className="text-xs text-gray-500">{label}</p>
              <p className="text-base font-bold text-gray-200">{value}</p>
            </div>
          ))}
        </div>
      </div>
      <p className="mt-2 text-sm text-gray-400">
        You can fine-tune retention values (change daily count, keep only monthly, keep N latest,
        etc.)
      </p>
    </div>
  );
}

View File

@@ -0,0 +1,432 @@
import { App, Button, Spin, Table, Tag } from 'antd';
import type { ColumnsType } from 'antd/es/table';
import dayjs from 'dayjs';
import { useEffect, useState } from 'react';
import { CLOUD_PRICE_PER_GB } from '../../../constants';
import {
type Invoice,
InvoiceStatus,
type Subscription,
type SubscriptionEvent,
SubscriptionEventType,
SubscriptionStatus,
billingApi,
} from '../../../entity/billing';
import type { Database } from '../../../entity/databases';
import { getUserShortTimeFormat, getUserTimeFormat } from '../../../shared/time';
import { PurchaseComponent } from './PurchaseComponent';
// Maximum number of invoice / event rows fetched and displayed per table.
const MAX_ROWS = 25;

// Ant Design Tag color per subscription status.
const STATUS_TAG_COLOR: Record<SubscriptionStatus, string> = {
  [SubscriptionStatus.Trial]: 'blue',
  [SubscriptionStatus.Active]: 'green',
  [SubscriptionStatus.PastDue]: 'orange',
  [SubscriptionStatus.Canceled]: 'red',
  [SubscriptionStatus.Expired]: 'default',
};

// Human-readable label per subscription status.
const STATUS_LABEL: Record<SubscriptionStatus, string> = {
  [SubscriptionStatus.Trial]: 'Trial',
  [SubscriptionStatus.Active]: 'Active',
  [SubscriptionStatus.PastDue]: 'Past Due',
  [SubscriptionStatus.Canceled]: 'Canceled',
  [SubscriptionStatus.Expired]: 'Expired',
};

// Ant Design Tag color per invoice status.
const INVOICE_STATUS_COLOR: Record<InvoiceStatus, string> = {
  [InvoiceStatus.Paid]: 'green',
  [InvoiceStatus.Pending]: 'blue',
  [InvoiceStatus.Failed]: 'red',
  [InvoiceStatus.Refunded]: 'orange',
  [InvoiceStatus.Disputed]: 'red',
};

// Human-readable label per invoice status.
const INVOICE_STATUS_LABEL: Record<InvoiceStatus, string> = {
  [InvoiceStatus.Paid]: 'Paid',
  [InvoiceStatus.Pending]: 'Pending',
  [InvoiceStatus.Failed]: 'Failed',
  [InvoiceStatus.Refunded]: 'Refunded',
  [InvoiceStatus.Disputed]: 'Disputed',
};

// Human-readable label per subscription lifecycle event type.
const EVENT_TYPE_LABEL: Record<SubscriptionEventType, string> = {
  [SubscriptionEventType.Created]: 'Subscription created',
  [SubscriptionEventType.Upgraded]: 'Storage upgraded',
  [SubscriptionEventType.Downgraded]: 'Storage downgraded',
  [SubscriptionEventType.NewBillingCycleStarted]: 'New billing cycle started',
  [SubscriptionEventType.Canceled]: 'Canceled',
  [SubscriptionEventType.Reactivated]: 'Reactivated',
  [SubscriptionEventType.Expired]: 'Expired',
  [SubscriptionEventType.PastDue]: 'Past Due',
  [SubscriptionEventType.RecoveredFromPastDue]: 'Recovered',
  [SubscriptionEventType.Refund]: 'Refund',
  [SubscriptionEventType.Dispute]: 'Dispute',
};
interface Props {
  database: Database;
  // Whether the current user may manage databases (controls purchase/portal buttons).
  isCanManageDBs: boolean;
}

/**
 * Billing tab for a single database: shows the current subscription summary,
 * recent invoices and subscription activity, and entry points to purchase,
 * change storage, or open the external billing portal.
 */
export const BillingComponent = ({ database, isCanManageDBs }: Props) => {
  const { message } = App.useApp();
  const [subscription, setSubscription] = useState<Subscription | null>(null);
  const [isLoadingSubscription, setIsLoadingSubscription] = useState(true);
  const [invoices, setInvoices] = useState<Invoice[]>([]);
  const [isLoadingInvoices, setIsLoadingInvoices] = useState(false);
  const [totalInvoices, setTotalInvoices] = useState(0);
  const [events, setEvents] = useState<SubscriptionEvent[]>([]);
  const [isLoadingEvents, setIsLoadingEvents] = useState(false);
  const [totalEvents, setTotalEvents] = useState(0);
  const [isPurchaseModalOpen, setIsPurchaseModalOpen] = useState(false);
  const [isPortalLoading, setIsPortalLoading] = useState(false);

  // Loads the subscription for this database; returns it so callers can chain
  // dependent loads (invoices/events). A failed load is shown as "no subscription".
  const loadSubscription = async (): Promise<Subscription | null> => {
    setIsLoadingSubscription(true);
    try {
      const sub = await billingApi.getSubscription(database.id);
      setSubscription(sub);
      return sub;
    } catch {
      setSubscription(null);
      return null;
    } finally {
      setIsLoadingSubscription(false);
    }
  };

  // Loads the first page (MAX_ROWS) of invoices for the subscription.
  const loadInvoices = async (subscriptionId: string) => {
    setIsLoadingInvoices(true);
    try {
      const response = await billingApi.getInvoices(subscriptionId, MAX_ROWS, 0);
      setInvoices(response.invoices);
      setTotalInvoices(response.total);
    } catch {
      setInvoices([]);
    } finally {
      setIsLoadingInvoices(false);
    }
  };

  // Loads the first page (MAX_ROWS) of subscription lifecycle events.
  const loadEvents = async (subscriptionId: string) => {
    setIsLoadingEvents(true);
    try {
      const response = await billingApi.getSubscriptionEvents(subscriptionId, MAX_ROWS, 0);
      setEvents(response.events);
      setTotalEvents(response.total);
    } catch {
      setEvents([]);
    } finally {
      setIsLoadingEvents(false);
    }
  };

  // Opens the external billing portal (Paddle-hosted) in a new tab.
  const handlePortalClick = async () => {
    if (!subscription) return;
    setIsPortalLoading(true);
    try {
      const result = await billingApi.getPortalSession(subscription.id);
      window.open(result.url, '_blank');
    } catch {
      message.error('Failed to open billing portal');
    } finally {
      setIsPortalLoading(false);
    }
  };

  // Called by the purchase modal after a successful change: refreshes the
  // subscription and, if one exists, its invoices and events.
  const handleSubscriptionChanged = async () => {
    const sub = await loadSubscription();
    if (sub) {
      loadInvoices(sub.id);
      loadEvents(sub.id);
    }
  };

  useEffect(() => {
    loadSubscription();
  }, [database.id]);

  // Reload the tables whenever the subscription identity changes.
  useEffect(() => {
    if (!subscription) return;
    loadInvoices(subscription.id);
    loadEvents(subscription.id);
  }, [subscription?.id]);

  const timeFormat = getUserTimeFormat();
  const shortFormat = getUserShortTimeFormat();

  // Purchase is offered while on trial or after expiry; the portal is available
  // once a paid subscription exists (active, past due, or canceled).
  const canPurchase =
    subscription &&
    (subscription.status === SubscriptionStatus.Trial ||
      subscription.status === SubscriptionStatus.Expired);
  const canAccessPortal =
    subscription &&
    (subscription.status === SubscriptionStatus.Active ||
      subscription.status === SubscriptionStatus.PastDue ||
      subscription.status === SubscriptionStatus.Canceled);
  const isTrial = subscription?.status === SubscriptionStatus.Trial;
  // Monthly price is linear in purchased storage.
  const monthlyPrice = subscription ? subscription.storageGb * CLOUD_PRICE_PER_GB : 0;

  const invoiceColumns: ColumnsType<Invoice> = [
    {
      title: 'Period',
      dataIndex: 'periodStart',
      render: (_: unknown, record: Invoice) =>
        `${dayjs.utc(record.periodStart).local().format(shortFormat.format)} - ${dayjs.utc(record.periodEnd).local().format(shortFormat.format)}`,
    },
    {
      title: 'Amount',
      dataIndex: 'amountCents',
      // Amounts are stored in cents; render as dollars.
      render: (cents: number) => `$${(cents / 100).toFixed(2)}`,
    },
    {
      title: 'Storage',
      dataIndex: 'storageGb',
      render: (gb: number) => `${gb} GB`,
    },
    {
      title: 'Status',
      dataIndex: 'status',
      render: (status: InvoiceStatus) => (
        <Tag color={INVOICE_STATUS_COLOR[status]}>{INVOICE_STATUS_LABEL[status]}</Tag>
      ),
    },
    {
      title: 'Paid At',
      dataIndex: 'paidAt',
      render: (paidAt: string | undefined) =>
        paidAt ? dayjs.utc(paidAt).local().format(timeFormat.format) : '-',
    },
  ];

  const eventColumns: ColumnsType<SubscriptionEvent> = [
    {
      title: 'Date',
      dataIndex: 'createdAt',
      render: (createdAt: string) => (
        <div>
          <div>{dayjs.utc(createdAt).local().format(timeFormat.format)}</div>
          <div className="text-xs text-gray-500">{dayjs.utc(createdAt).local().fromNow()}</div>
        </div>
      ),
    },
    {
      title: 'Event',
      dataIndex: 'type',
      // Fall back to the raw type for any event the label map does not know.
      render: (type: SubscriptionEventType) => EVENT_TYPE_LABEL[type] ?? type,
    },
    {
      title: 'Details',
      // Shows "old -> new" transitions for storage and/or status, when present.
      render: (_: unknown, record: SubscriptionEvent) => {
        const parts: string[] = [];
        if (record.oldStorageGb != null && record.newStorageGb != null) {
          parts.push(`${record.oldStorageGb} GB \u2192 ${record.newStorageGb} GB`);
        }
        if (record.oldStatus != null && record.newStatus != null) {
          parts.push(
            `${STATUS_LABEL[record.oldStatus] ?? record.oldStatus} \u2192 ${STATUS_LABEL[record.newStatus] ?? record.newStatus}`,
          );
        }
        return parts.length > 0 ? parts.join(', ') : '-';
      },
    },
  ];

  if (isLoadingSubscription) {
    return (
      <div className="flex w-full justify-center rounded-tr-md rounded-br-md rounded-bl-md bg-white p-10 shadow dark:bg-gray-800">
        <Spin size="large" />
      </div>
    );
  }

  return (
    <div className="w-full rounded-tr-md rounded-br-md rounded-bl-md bg-white p-3 shadow md:p-5 dark:bg-gray-800">
      <div className="max-w-[720px]">
        <h2 className="text-lg font-bold md:text-xl dark:text-white">Billing</h2>
        {/* Subscription Summary */}
        {!subscription && (
          <div className="mt-4">
            <p className="text-gray-500 dark:text-gray-400">
              No subscription found for this database.
            </p>
            {isCanManageDBs && (
              <Button type="primary" className="mt-3" onClick={() => setIsPurchaseModalOpen(true)}>
                Purchase
              </Button>
            )}
          </div>
        )}
        {subscription && (
          <>
            <div className="mt-4 rounded-lg border border-gray-200 p-4 dark:border-gray-700">
              <div className="flex items-center gap-2">
                <Tag
                  color={STATUS_TAG_COLOR[subscription.status]}
                  style={{ fontSize: 14, padding: '2px 12px' }}
                >
                  {STATUS_LABEL[subscription.status]}
                </Tag>
                {!isTrial && (
                  <span className="text-2xl font-bold dark:text-white">
                    ${monthlyPrice.toFixed(2)}
                    <span className="text-sm font-normal text-gray-500">/mo</span>
                  </span>
                )}
              </div>
              <div className="mt-4 grid grid-cols-2 gap-3">
                <div className="rounded-md bg-gray-50 px-3 py-2 dark:bg-gray-700/50">
                  <div className="text-xs text-gray-500 dark:text-gray-400">Storage</div>
                  <div className="font-medium dark:text-gray-200">
                    {subscription.storageGb} GB
                    {subscription.pendingStorageGb != null && (
                      <span className="ml-1 text-xs text-yellow-500">
                        ({subscription.pendingStorageGb} GB pending)
                      </span>
                    )}
                  </div>
                </div>
                <div className="rounded-md bg-gray-50 px-3 py-2 dark:bg-gray-700/50">
                  <div className="text-xs text-gray-500 dark:text-gray-400">Current period</div>
                  <div className="font-medium dark:text-gray-200">
                    {dayjs.utc(subscription.currentPeriodStart).local().format(shortFormat.format)}{' '}
                    - {dayjs.utc(subscription.currentPeriodEnd).local().format(shortFormat.format)}
                  </div>
                </div>
                {subscription.canceledAt && (
                  <div className="rounded-md bg-red-50 px-3 py-2 dark:bg-red-900/20">
                    <div className="text-xs text-gray-500 dark:text-gray-400">Canceled at</div>
                    <div className="font-medium text-red-500">
                      {dayjs.utc(subscription.canceledAt).local().format(timeFormat.format)}
                    </div>
                    <div className="text-xs text-gray-500">
                      {dayjs.utc(subscription.canceledAt).local().fromNow()}
                    </div>
                  </div>
                )}
                {subscription.dataRetentionGracePeriodUntil && (
                  <div className="rounded-md bg-yellow-50 px-3 py-2 dark:bg-yellow-900/20">
                    <div className="text-xs text-gray-500 dark:text-gray-400">
                      Data retained until
                    </div>
                    <div className="font-medium text-yellow-500">
                      {dayjs
                        .utc(subscription.dataRetentionGracePeriodUntil)
                        .local()
                        .format(timeFormat.format)}
                    </div>
                    <div className="text-xs text-gray-500">
                      {dayjs.utc(subscription.dataRetentionGracePeriodUntil).local().fromNow()}
                    </div>
                  </div>
                )}
              </div>
              {isCanManageDBs && (
                <div className="mt-4 flex flex-wrap gap-2 border-t border-gray-200 pt-4 dark:border-gray-700">
                  {canPurchase && (
                    <Button type="primary" onClick={() => setIsPurchaseModalOpen(true)}>
                      Purchase
                    </Button>
                  )}
                  {canAccessPortal && (
                    <>
                      {subscription.status !== SubscriptionStatus.Canceled && (
                        <Button onClick={() => setIsPurchaseModalOpen(true)}>Change storage</Button>
                      )}
                      {subscription.status === SubscriptionStatus.Canceled ? (
                        <Button
                          type="primary"
                          loading={isPortalLoading}
                          onClick={handlePortalClick}
                        >
                          Resume subscription
                        </Button>
                      ) : (
                        <Button loading={isPortalLoading} onClick={handlePortalClick}>
                          Manage subscription
                        </Button>
                      )}
                    </>
                  )}
                </div>
              )}
            </div>
            {/* Invoices */}
            <h3 className="mt-6 mb-3 text-base font-bold dark:text-white">Invoices</h3>
            <Table
              bordered
              size="small"
              columns={invoiceColumns}
              dataSource={invoices}
              rowKey="id"
              loading={isLoadingInvoices}
              pagination={false}
            />
            {totalInvoices > MAX_ROWS && (
              <p className="mt-1 text-xs text-gray-500">
                Showing {MAX_ROWS} of {totalInvoices} invoices
              </p>
            )}
            {/* Activity */}
            <h3 className="mt-6 mb-3 text-base font-bold dark:text-white">Activity</h3>
            <Table
              bordered
              size="small"
              columns={eventColumns}
              dataSource={events}
              rowKey="id"
              loading={isLoadingEvents}
              pagination={false}
            />
            {totalEvents > MAX_ROWS && (
              <p className="mt-1 text-xs text-gray-500">
                Showing {MAX_ROWS} of {totalEvents} events
              </p>
            )}
          </>
        )}
        {isPurchaseModalOpen && (
          <PurchaseComponent
            databaseId={database.id}
            onSubscriptionChanged={handleSubscriptionChanged}
            onClose={() => setIsPurchaseModalOpen(false)}
          />
        )}
      </div>
    </div>
  );
};

View File

@@ -0,0 +1,64 @@
import { useState } from 'react';
interface DbSizeCommand {
  // Display name of the database engine.
  label: string;
  // Copy-paste command/query for measuring the database size.
  code: string;
}

interface Props {
  commands: DbSizeCommand[];
}

/**
 * Collapsible "How to check DB size?" section listing per-engine size queries,
 * each with a copy-to-clipboard button that shows "Copied!" for two seconds.
 */
export function DbSizeCommands({ commands }: Props) {
  // Index of the command whose code was most recently copied (for button feedback).
  const [copiedIndex, setCopiedIndex] = useState<number | null>(null);
  return (
    <details className="group mb-2">
      <summary className="flex cursor-pointer list-none items-center gap-1.5 text-gray-500 transition-colors hover:text-gray-400">
        <svg
          width="12"
          height="12"
          viewBox="0 0 24 24"
          fill="none"
          stroke="currentColor"
          strokeWidth="2"
          strokeLinecap="round"
          strokeLinejoin="round"
          className="transition-transform group-open:rotate-90"
        >
          <path d="M9 18l6-6-6-6" />
        </svg>
        How to check DB size?
      </summary>
      <div className="mt-2 space-y-1.5">
        {commands.map((cmd, index) => (
          <div key={index}>
            <p className="mb-1 text-xs text-gray-400">{cmd.label}</p>
            <div className="relative">
              <pre className="overflow-x-auto rounded-lg border border-[#ffffff20] bg-[#1f2937] px-2.5 py-1.5 pr-16 text-xs">
                <code className="block whitespace-pre text-gray-300">{cmd.code}</code>
              </pre>
              <button
                onClick={async () => {
                  try {
                    await navigator.clipboard.writeText(cmd.code);
                    setCopiedIndex(index);
                    setTimeout(() => setCopiedIndex(null), 2000);
                  } catch {
                    /* ignore */
                  }
                }}
                className={`absolute top-2 right-2 rounded border border-[#ffffff20] px-2 py-0.5 text-white transition-colors ${
                  copiedIndex === index ? 'bg-green-500' : 'bg-blue-600 hover:bg-blue-700'
                }`}
              >
                {copiedIndex === index ? 'Copied!' : 'Copy'}
              </button>
            </div>
          </div>
        ))}
      </div>
    </details>
  );
}

View File

@@ -0,0 +1,74 @@
import { Button } from 'antd';
import { SubscriptionStatus } from '../../../entity/billing';
interface Props {
  // Price for the storage currently selected in the dialog.
  monthlyPrice: number;
  // Price for the storage the subscription has right now.
  currentPrice: number;
  // Exactly one of isPurchaseFlow / isChangeFlow is expected to be true.
  isPurchaseFlow: boolean;
  isChangeFlow: boolean;
  isUpgrade: boolean;
  isDowngrade: boolean;
  isSameStorage: boolean;
  isSubmitting: boolean;
  subscriptionStatus: SubscriptionStatus;
  onPurchase: () => void;
  onChangeStorage: () => void;
}

/**
 * Bottom bar of the purchase dialog: shows the selected monthly price and the
 * primary action (Purchase / Re-subscribe for new subscriptions, or
 * Upgrade / Downgrade for an existing one).
 */
export function PriceActionBar({
  monthlyPrice,
  currentPrice,
  isPurchaseFlow,
  isChangeFlow,
  isUpgrade,
  isDowngrade,
  isSameStorage,
  isSubmitting,
  subscriptionStatus,
  onPurchase,
  onChangeStorage,
}: Props) {
  return (
    <div className="mt-4 flex items-center gap-4 border-t border-[#ffffff20] pt-4">
      <div className="flex-1">
        <p className="text-2xl font-bold">
          ${monthlyPrice.toFixed(2)}
          <span className="text-base font-medium text-gray-400">/mo</span>
        </p>
        {isChangeFlow && !isSameStorage && (
          <p className="text-xs text-gray-400">Currently ${currentPrice.toFixed(2)}/mo</p>
        )}
      </div>
      <div className="flex flex-col items-end gap-1">
        {isPurchaseFlow && (
          <Button type="primary" size="large" loading={isSubmitting} onClick={onPurchase}>
            {subscriptionStatus === SubscriptionStatus.Canceled ? 'Re-subscribe' : 'Purchase'}
          </Button>
        )}
        {isChangeFlow && (
          <>
            <Button
              type="primary"
              size="large"
              loading={isSubmitting}
              disabled={!!isSameStorage}
              onClick={onChangeStorage}
            >
              {isUpgrade ? 'Upgrade' : isDowngrade ? 'Downgrade' : 'Change Storage'}
            </Button>
            {isDowngrade && (
              <p className="text-xs text-gray-500">
                Storage will be reduced from next billing cycle
              </p>
            )}
          </>
        )}
      </div>
    </div>
  );
}

View File

@@ -0,0 +1,54 @@
/* Custom-styled range input used by the storage/backup size sliders.
   The track background itself is set inline via sliderBackground() in JS;
   these rules style the track shape and the thumb for WebKit and Firefox. */
.calcSlider {
  -webkit-appearance: none;
  appearance: none;
  width: 100%;
  height: 6px;
  border-radius: 3px;
  outline: none;
  cursor: pointer;
}
/* WebKit thumb: blue circle with white ring and glow; grows slightly on hover. */
.calcSlider::-webkit-slider-thumb {
  -webkit-appearance: none;
  appearance: none;
  width: 20px;
  height: 20px;
  border-radius: 50%;
  background: #155dfc;
  border: 2px solid #fff;
  cursor: pointer;
  transition: transform 0.15s;
  box-shadow: 0 0 8px rgba(21, 93, 252, 0.5);
}
.calcSlider::-webkit-slider-thumb:hover {
  transform: scale(1.15);
}
/* Firefox thumb (must be styled separately from WebKit). */
.calcSlider::-moz-range-thumb {
  width: 20px;
  height: 20px;
  border-radius: 50%;
  background: #155dfc;
  border: 2px solid #fff;
  cursor: pointer;
  box-shadow: 0 0 8px rgba(21, 93, 252, 0.5);
}
.calcSlider::-moz-range-track {
  height: 6px;
  border-radius: 3px;
  background: #1f2937;
}
/* Firefox paints the filled portion natively via -moz-range-progress. */
.calcSlider::-moz-range-progress {
  background: #155dfc;
  border-radius: 3px;
  height: 6px;
}
/* Keyboard-focus ring for accessibility. */
.calcSlider:focus-visible::-webkit-slider-thumb {
  box-shadow:
    0 0 0 3px rgba(21, 93, 252, 0.3),
    0 0 8px rgba(21, 93, 252, 0.5);
}

View File

@@ -0,0 +1,189 @@
import { Button, Modal, Spin } from 'antd';
import { useEffect, useState } from 'react';
import { usePurchaseFlow } from '../hooks/usePurchaseFlow';
import { CLOUD_PRICE_PER_GB } from '../../../constants';
import { SubscriptionStatus } from '../../../entity/billing';
import {
BACKUPS_COMPRESSION_RATIO,
BACKUP_SIZE_STEPS,
STORAGE_SIZE_STEPS,
distributeGfs,
formatSize,
} from '../models/purchaseUtils';
import { BackupRetentionSection } from './BackupRetentionSection';
import { PriceActionBar } from './PriceActionBar';
import { StorageSlidersSection } from './StorageSlidersSection';
interface Props {
  databaseId: string;
  // Invoked whenever the subscription changes server-side (purchase/upgrade).
  onSubscriptionChanged: () => void;
  onClose: () => void;
}

/**
 * Modal dialog for purchasing a subscription or changing its storage size.
 * Two sliders (single backup size, total storage) drive derived values:
 * estimated DB size, how many backups fit, a suggested GFS split, and the
 * monthly price. Flow state (loading, checkout, polling, timeouts) comes
 * from usePurchaseFlow.
 */
export function PurchaseComponent({ databaseId, onSubscriptionChanged, onClose }: Props) {
  const flow = usePurchaseFlow({ databaseId, onSubscriptionChanged, onClose });
  // Slider positions are indices into BACKUP_SIZE_STEPS / STORAGE_SIZE_STEPS.
  const [storageSliderPos, setStorageSliderPos] = useState(0);
  const [backupSliderPos, setBackupSliderPos] = useState(0);

  // Once the subscription loads, snap the storage slider to its current size.
  useEffect(() => {
    if (flow.initialSliderPos > 0) {
      setStorageSliderPos(flow.initialSliderPos);
    }
  }, [flow.initialSliderPos]);

  const singleBackupSizeGb = BACKUP_SIZE_STEPS[backupSliderPos];
  // Total storage must fit at least one backup, so the effective storage
  // position is clamped to the first step >= the selected backup size.
  const minStoragePosIndex = STORAGE_SIZE_STEPS.findIndex((s) => s >= singleBackupSizeGb);
  const minStoragePos =
    minStoragePosIndex === -1 ? STORAGE_SIZE_STEPS.length - 1 : minStoragePosIndex;
  const effectiveStoragePos = Math.max(storageSliderPos, minStoragePos);
  const newStorageGb = STORAGE_SIZE_STEPS[effectiveStoragePos];
  // Estimate of the live DB size from the compressed backup size.
  const approximateDbSize = singleBackupSizeGb * BACKUPS_COMPRESSION_RATIO;
  const backupsFit = Math.floor(newStorageGb / singleBackupSizeGb);
  const gfs = distributeGfs(backupsFit);
  const monthlyPrice = newStorageGb * CLOUD_PRICE_PER_GB;

  const { subscription } = flow;
  // Purchase flow = creating a new paid subscription; change flow = resizing
  // an already-paying one.
  const isPurchaseFlow =
    subscription &&
    (subscription.status === SubscriptionStatus.Trial ||
      subscription.status === SubscriptionStatus.Canceled ||
      subscription.status === SubscriptionStatus.Expired);
  const isChangeFlow =
    subscription &&
    (subscription.status === SubscriptionStatus.Active ||
      subscription.status === SubscriptionStatus.PastDue);
  const isUpgrade = isChangeFlow && newStorageGb > subscription.storageGb;
  const isDowngrade = isChangeFlow && newStorageGb < subscription.storageGb;
  const isSameStorage = isChangeFlow && newStorageGb === subscription.storageGb;
  const currentPrice = subscription ? subscription.storageGb * CLOUD_PRICE_PER_GB : 0;
  const modalTitle = isPurchaseFlow
    ? subscription.status === SubscriptionStatus.Canceled
      ? 'Re-subscribe'
      : 'Purchase subscription'
    : 'Change Storage';
  // The slider form is hidden while any loading/polling/result state is active.
  const isShowingForm =
    subscription &&
    !flow.isLoadingSubscription &&
    !flow.isWaitingForUpgrade &&
    !flow.isUpgradeTimedOut &&
    !flow.isWaitingForPayment &&
    !flow.isPaymentConfirmed &&
    !flow.isPaymentTimedOut &&
    !flow.isCheckoutOpen;

  return (
    <Modal
      title={modalTitle}
      open
      onCancel={onClose}
      footer={null}
      width={700}
      maskClosable={false}
    >
      {flow.isLoadingSubscription && (
        <div className="flex justify-center py-10">
          <Spin size="large" />
        </div>
      )}
      {flow.loadError && <div className="py-10 text-center text-red-500">{flow.loadError}</div>}
      {flow.isWaitingForPayment && (
        <div className="flex flex-col items-center gap-4 py-10">
          <Spin size="large" />
          <p className="text-gray-400">Confirming your payment...</p>
        </div>
      )}
      {flow.isPaymentConfirmed && (
        <div className="py-6 text-center">
          <p className="mb-1 text-lg font-semibold text-green-600 dark:text-green-400">
            Payment successful!
          </p>
          {flow.confirmedStorageGb !== undefined && (
            <p className="mb-4 text-gray-500 dark:text-gray-400">
              Your subscription is now active with {flow.confirmedStorageGb} GB of storage.
            </p>
          )}
          <Button type="primary" onClick={onClose}>
            OK
          </Button>
        </div>
      )}
      {flow.isPaymentTimedOut && (
        <div className="py-6 text-center">
          <p className="mb-4 text-yellow-500">
            Payment confirmation is taking longer than expected. Please reload the page to check the
            status.
          </p>
          <Button onClick={() => window.location.reload()}>Reload page</Button>
        </div>
      )}
      {flow.isUpgradeTimedOut && (
        <div className="py-6 text-center">
          <p className="mb-2 text-yellow-500">
            Upgrade is taking longer than expected, it will be applied shortly. Please reload the
            page
          </p>
          <Button onClick={onClose}>Close</Button>
        </div>
      )}
      {flow.isWaitingForUpgrade && !flow.isUpgradeTimedOut && (
        <div className="flex flex-col items-center gap-4 py-10">
          <Spin size="large" />
          <p className="text-gray-400">Waiting for storage upgrade confirmation...</p>
        </div>
      )}
      {isShowingForm && (
        <div>
          {isChangeFlow && subscription.pendingStorageGb !== undefined && (
            <div className="mb-4 rounded-lg border border-yellow-600/30 bg-yellow-900/20 px-4 py-3 text-sm text-yellow-400">
              Pending storage change to {formatSize(subscription.pendingStorageGb)} from next
              billing cycle
            </div>
          )}
          <div className="md:grid md:grid-cols-2 md:gap-6">
            <StorageSlidersSection
              onStorageSliderChange={setStorageSliderPos}
              backupSliderPos={backupSliderPos}
              onBackupSliderChange={setBackupSliderPos}
              effectiveStoragePos={effectiveStoragePos}
              newStorageGb={newStorageGb}
              singleBackupSizeGb={singleBackupSizeGb}
              approximateDbSize={approximateDbSize}
            />
            <BackupRetentionSection backupsFit={backupsFit} gfs={gfs} />
          </div>
          <PriceActionBar
            monthlyPrice={monthlyPrice}
            currentPrice={currentPrice}
            isPurchaseFlow={!!isPurchaseFlow}
            isChangeFlow={!!isChangeFlow}
            isUpgrade={!!isUpgrade}
            isDowngrade={!!isDowngrade}
            isSameStorage={!!isSameStorage}
            isSubmitting={flow.isSubmitting}
            subscriptionStatus={subscription.status}
            onPurchase={() => flow.handlePurchase(newStorageGb)}
            onChangeStorage={() => flow.handleStorageChange(newStorageGb)}
          />
        </div>
      )}
    </Modal>
  );
}

Some files were not shown because too many files have changed in this diff Show More