mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
FEATURE (cloud): Add cloud
This commit is contained in:
609
AGENTS.md
609
AGENTS.md
@@ -9,7 +9,6 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
|
||||
|
||||
- [Engineering philosophy](#engineering-philosophy)
|
||||
- [Backend guidelines](#backend-guidelines)
|
||||
- [Code style](#code-style)
|
||||
- [Boolean naming](#boolean-naming)
|
||||
- [Add reasonable new lines between logical statements](#add-reasonable-new-lines-between-logical-statements)
|
||||
- [Comments](#comments)
|
||||
@@ -19,6 +18,7 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
|
||||
- [Refactoring](#refactoring)
|
||||
- [Testing](#testing)
|
||||
- [Time handling](#time-handling)
|
||||
- [Logging](#logging)
|
||||
- [CRUD examples](#crud-examples)
|
||||
- [Frontend guidelines](#frontend-guidelines)
|
||||
- [React component structure](#react-component-structure)
|
||||
@@ -122,157 +122,6 @@ Good:
|
||||
|
||||
Exclusion: widely used variables like "db", "ctx", "req", "res", etc.
|
||||
|
||||
### Code style
|
||||
|
||||
**Always place private methods to the bottom of file**
|
||||
|
||||
This rule applies to ALL Go files including tests, services, controllers, repositories, etc.
|
||||
|
||||
In Go, exported (public) functions/methods start with uppercase letters, while unexported (private) ones start with lowercase letters.
|
||||
|
||||
#### Structure order:
|
||||
|
||||
1. Type definitions and constants
|
||||
2. Public methods/functions (uppercase)
|
||||
3. Private methods/functions (lowercase)
|
||||
|
||||
#### Examples:
|
||||
|
||||
**Service with methods:**
|
||||
|
||||
```go
|
||||
type UserService struct {
|
||||
repository *UserRepository
|
||||
}
|
||||
|
||||
// Public methods first
|
||||
func (s *UserService) CreateUser(user *User) error {
|
||||
if err := s.validateUser(user); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.repository.Save(user)
|
||||
}
|
||||
|
||||
func (s *UserService) GetUser(id uuid.UUID) (*User, error) {
|
||||
return s.repository.FindByID(id)
|
||||
}
|
||||
|
||||
// Private methods at the bottom
|
||||
func (s *UserService) validateUser(user *User) error {
|
||||
if user.Name == "" {
|
||||
return errors.New("name is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
**Package-level functions:**
|
||||
|
||||
```go
|
||||
package utils
|
||||
|
||||
// Public functions first
|
||||
func ProcessData(data []byte) (Result, error) {
|
||||
cleaned := sanitizeInput(data)
|
||||
return parseData(cleaned)
|
||||
}
|
||||
|
||||
func ValidateInput(input string) bool {
|
||||
return isValidFormat(input) && checkLength(input)
|
||||
}
|
||||
|
||||
// Private functions at the bottom
|
||||
func sanitizeInput(data []byte) []byte {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func parseData(data []byte) (Result, error) {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func isValidFormat(input string) bool {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func checkLength(input string) bool {
|
||||
// implementation
|
||||
}
|
||||
```
|
||||
|
||||
**Test files:**
|
||||
|
||||
```go
|
||||
package user_test
|
||||
|
||||
// Public test functions first
|
||||
func Test_CreateUser_ValidInput_UserCreated(t *testing.T) {
|
||||
user := createTestUser()
|
||||
result, err := service.CreateUser(user)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func Test_GetUser_ExistingUser_ReturnsUser(t *testing.T) {
|
||||
user := createTestUser()
|
||||
// test implementation
|
||||
}
|
||||
|
||||
// Private helper functions at the bottom
|
||||
func createTestUser() *User {
|
||||
return &User{
|
||||
Name: "Test User",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestDatabase() *Database {
|
||||
// setup implementation
|
||||
}
|
||||
```
|
||||
|
||||
**Controller example:**
|
||||
|
||||
```go
|
||||
type ProjectController struct {
|
||||
service *ProjectService
|
||||
}
|
||||
|
||||
// Public HTTP handlers first
|
||||
func (c *ProjectController) CreateProject(ctx *gin.Context) {
|
||||
var request CreateProjectRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
c.handleError(ctx, err)
|
||||
return
|
||||
}
|
||||
// handler logic
|
||||
}
|
||||
|
||||
func (c *ProjectController) GetProject(ctx *gin.Context) {
|
||||
projectID := c.extractProjectID(ctx)
|
||||
// handler logic
|
||||
}
|
||||
|
||||
// Private helper methods at the bottom
|
||||
func (c *ProjectController) handleError(ctx *gin.Context, err error) {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
}
|
||||
|
||||
func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
|
||||
return uuid.MustParse(ctx.Param("projectId"))
|
||||
}
|
||||
```
|
||||
|
||||
#### Key points:
|
||||
|
||||
- **Exported/Public** = starts with uppercase letter (CreateUser, GetProject)
|
||||
- **Unexported/Private** = starts with lowercase letter (validateUser, handleError)
|
||||
- This improves code readability by showing the public API first
|
||||
- Private helpers are implementation details, so they go at the bottom
|
||||
- Apply this rule consistently across ALL Go files in the project
|
||||
|
||||
---
|
||||
|
||||
### Boolean naming
|
||||
|
||||
**Always prefix boolean variables with verbs like `is`, `has`, `was`, `should`, `can`, etc.**
|
||||
@@ -1001,6 +850,20 @@ func Test_BackupLifecycle_CreateAndDelete(t *testing.T) {
|
||||
}
|
||||
```
|
||||
|
||||
#### Cloud testing
|
||||
|
||||
If you are testing cloud, set isCloud = true before test run and defer isCloud = false after test run. Example helper function:
|
||||
|
||||
```go
|
||||
func enableCloud(t *testing.T) {
|
||||
t.Helper()
|
||||
config.GetEnv().IsCloud = true
|
||||
t.Cleanup(func() {
|
||||
config.GetEnv().IsCloud = false
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
#### Testing utilities structure
|
||||
|
||||
**Create `testing.go` or `testing/testing.go` files with common utilities:**
|
||||
@@ -1112,6 +975,100 @@ This ensures consistent timezone handling across the application.
|
||||
|
||||
---
|
||||
|
||||
### Logging
|
||||
|
||||
We use `log/slog` for structured logging. Follow these conventions to keep logs consistent, searchable, and useful for debugging.
|
||||
|
||||
#### Scoped loggers for tracing
|
||||
|
||||
Attach IDs via `logger.With(...)` as early as possible so every downstream log line carries them automatically. Common IDs: `database_id`, `subscription_id`, `backup_id`, `storage_id`, `user_id`.
|
||||
|
||||
```go
|
||||
func (s *BillingService) CreateSubscription(logger *slog.Logger, user *users_models.User, databaseID uuid.UUID, storageGB int) {
|
||||
logger = logger.With("database_id", databaseID)
|
||||
|
||||
// all subsequent logger calls automatically include database_id
|
||||
logger.Debug(fmt.Sprintf("creating subscription for storage %d GB", storageGB))
|
||||
}
|
||||
```
|
||||
|
||||
For background services, create scoped loggers with `task_name` for each subtask in `Run()`:
|
||||
|
||||
```go
|
||||
func (c *BackupCleaner) Run(ctx context.Context) {
|
||||
retentionLog := c.logger.With("task_name", "clean_by_retention_policy")
|
||||
exceededLog := c.logger.With("task_name", "clean_exceeded_backups")
|
||||
|
||||
// pass scoped logger to each method
|
||||
c.cleanByRetentionPolicy(retentionLog)
|
||||
c.cleanExceededBackups(exceededLog)
|
||||
}
|
||||
```
|
||||
|
||||
Within loops, scope further:
|
||||
|
||||
```go
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
dbLog := logger.With("database_id", backupConfig.DatabaseID, "policy", backupConfig.RetentionPolicyType)
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
#### Values in message, IDs as kv pairs
|
||||
|
||||
**Values and statuses** (sizes, counts, status transitions) go into the message via `fmt.Sprintf`:
|
||||
|
||||
```go
|
||||
logger.Info(fmt.Sprintf("subscription renewed: %s -> %s, %d GB", oldStatus, newStatus, sub.StorageGB))
|
||||
logger.Info(
|
||||
fmt.Sprintf("deleted exceeded backup: backup size is %.1f MB, total size is %.1f MB, limit is %d MB",
|
||||
backup.BackupSizeMb, backupsTotalSizeMB, limitPerDbMB),
|
||||
"backup_id", backup.ID,
|
||||
)
|
||||
```
|
||||
|
||||
**IDs** stay as structured kv pairs — never inline them into the message string. This keeps them searchable in log aggregation tools:
|
||||
|
||||
```go
|
||||
// good
|
||||
logger.Info("deleted old backup", "backup_id", backup.ID)
|
||||
|
||||
// bad — ID buried in message, not searchable
|
||||
logger.Info(fmt.Sprintf("deleted old backup %s", backup.ID))
|
||||
```
|
||||
|
||||
**`error` is always a kv pair**, never inlined into the message:
|
||||
|
||||
```go
|
||||
// good
|
||||
logger.Error("failed to save subscription", "error", err)
|
||||
|
||||
// bad
|
||||
logger.Error(fmt.Sprintf("failed to save subscription: %v", err))
|
||||
```
|
||||
|
||||
#### Key naming and message style
|
||||
|
||||
- **snake_case for all log keys**: `database_id`, `backup_id`, `task_name`, `total_size_mb` — not camelCase
|
||||
- **Lowercase log messages**: start with lowercase, no trailing period
|
||||
|
||||
```go
|
||||
// good
|
||||
logger.Error("failed to create checkout session", "error", err)
|
||||
|
||||
// bad
|
||||
logger.Error("Failed to create checkout session.", "error", err)
|
||||
```
|
||||
|
||||
#### Log level usage
|
||||
|
||||
- **Debug**: routine operations, entering a function, query results count (`"getting subscription events"`, `"found 5 invoices"`)
|
||||
- **Info**: significant state changes, completed actions (`"subscription activated"`, `"deleted exceeded backup"`)
|
||||
- **Warn**: degraded but recoverable situations (`"oldest backup is too recent to delete"`, `"requested storage is the same as current"`)
|
||||
- **Error**: failures that need attention (`"failed to save subscription"`, `"failed to delete backup file"`)
|
||||
|
||||
---
|
||||
|
||||
### CRUD examples
|
||||
|
||||
This is an example of complete CRUD implementation structure:
|
||||
@@ -1127,7 +1084,6 @@ import (
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
@@ -1135,7 +1091,6 @@ type AuditLogController struct {
|
||||
}
|
||||
|
||||
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// All audit log endpoints require authentication (handled in main.go)
|
||||
auditRoutes := router.Group("/audit-logs")
|
||||
|
||||
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
|
||||
@@ -1151,7 +1106,6 @@ func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// @Security BearerAuth
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
@@ -1182,54 +1136,7 @@ func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetUserAuditLogs
|
||||
// @Summary Get user audit logs
|
||||
// @Description Retrieve audit logs for a specific user
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param userId path string true "User ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/users/{userId} [get]
|
||||
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
userIDStr := ctx.Param("userId")
|
||||
targetUserID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
// GetUserAuditLogs follows the same pattern...
|
||||
```
|
||||
|
||||
#### controller_test.go
|
||||
@@ -1237,34 +1144,13 @@ func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
```go
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_GetGlobalAuditLogs_AdminSucceedsAndMemberGetsForbidden(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
memberUser := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
projectID := uuid.New()
|
||||
|
||||
// Create test logs
|
||||
createAuditLog(service, "Test log with user", &adminUser.UserID, nil)
|
||||
createAuditLog(service, "Test log with project", nil, &projectID)
|
||||
createAuditLog(service, "Test log standalone", nil, nil)
|
||||
|
||||
// Test ADMIN can access global logs
|
||||
@@ -1272,13 +1158,9 @@ func Test_GetGlobalAuditLogs_AdminSucceedsAndMemberGetsForbidden(t *testing.T) {
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
"/api/v1/audit-logs/global?limit=10", "Bearer "+adminUser.Token, http.StatusOK, &response)
|
||||
|
||||
assert.GreaterOrEqual(t, len(response.AuditLogs), 3)
|
||||
assert.GreaterOrEqual(t, response.Total, int64(3))
|
||||
|
||||
assert.GreaterOrEqual(t, len(response.AuditLogs), 2)
|
||||
messages := extractMessages(response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test log with user")
|
||||
assert.Contains(t, messages, "Test log with project")
|
||||
assert.Contains(t, messages, "Test log standalone")
|
||||
|
||||
// Test MEMBER cannot access global logs
|
||||
resp := test_utils.MakeGetRequest(t, router, "/api/v1/audit-logs/global",
|
||||
@@ -1286,79 +1168,6 @@ func Test_GetGlobalAuditLogs_AdminSucceedsAndMemberGetsForbidden(t *testing.T) {
|
||||
assert.Contains(t, string(resp.Body), "only administrators can view global audit logs")
|
||||
}
|
||||
|
||||
func Test_GetUserAuditLogs_PermissionsEnforcedCorrectly(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
projectID := uuid.New()
|
||||
|
||||
// Create test logs for different users
|
||||
createAuditLog(service, "Test log user1 first", &user1.UserID, nil)
|
||||
createAuditLog(service, "Test log user1 second", &user1.UserID, &projectID)
|
||||
createAuditLog(service, "Test log user2 first", &user2.UserID, nil)
|
||||
createAuditLog(service, "Test log user2 second", &user2.UserID, &projectID)
|
||||
createAuditLog(service, "Test project log", nil, &projectID)
|
||||
|
||||
// Test ADMIN can view any user's logs
|
||||
var user1Response GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=10", user1.UserID.String()),
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &user1Response)
|
||||
|
||||
assert.Equal(t, 2, len(user1Response.AuditLogs))
|
||||
messages := extractMessages(user1Response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test log user1 first")
|
||||
assert.Contains(t, messages, "Test log user1 second")
|
||||
|
||||
// Test user can view own logs
|
||||
var ownLogsResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s", user2.UserID.String()),
|
||||
"Bearer "+user2.Token, http.StatusOK, &ownLogsResponse)
|
||||
assert.Equal(t, 2, len(ownLogsResponse.AuditLogs))
|
||||
|
||||
// Test user cannot view other user's logs
|
||||
resp := test_utils.MakeGetRequest(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s", user1.UserID.String()),
|
||||
"Bearer "+user2.Token, http.StatusForbidden)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
func Test_FilterAuditLogsByTime_ReturnsOnlyLogsBeforeDate(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
db := storage.GetDb()
|
||||
baseTime := time.Now().UTC()
|
||||
|
||||
// Create logs with different timestamps
|
||||
createTimedLog(db, &adminUser.UserID, "Test old log", baseTime.Add(-2*time.Hour))
|
||||
createTimedLog(db, &adminUser.UserID, "Test recent log", baseTime.Add(-30*time.Minute))
|
||||
createAuditLog(service, "Test current log", &adminUser.UserID, nil)
|
||||
|
||||
// Test filtering - get logs before 1 hour ago
|
||||
beforeTime := baseTime.Add(-1 * time.Hour)
|
||||
var filteredResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/global?beforeDate=%s", beforeTime.Format(time.RFC3339)),
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &filteredResponse)
|
||||
|
||||
// Verify only old log is returned
|
||||
messages := extractMessages(filteredResponse.AuditLogs)
|
||||
assert.Contains(t, messages, "Test old log")
|
||||
assert.NotContains(t, messages, "Test recent log")
|
||||
assert.NotContains(t, messages, "Test current log")
|
||||
|
||||
// Test without filter - should get all logs
|
||||
var allResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router, "/api/v1/audit-logs/global",
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &allResponse)
|
||||
assert.GreaterOrEqual(t, len(allResponse.AuditLogs), 3)
|
||||
}
|
||||
|
||||
func createRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
@@ -1384,12 +1193,10 @@ import (
|
||||
|
||||
var auditLogRepository = &AuditLogRepository{}
|
||||
var auditLogService = &AuditLogService{
|
||||
auditLogRepository: auditLogRepository,
|
||||
logger: logger.GetLogger(),
|
||||
}
|
||||
var auditLogController = &AuditLogController{
|
||||
auditLogService: auditLogService,
|
||||
auditLogRepository,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
var auditLogController = &AuditLogController{auditLogService}
|
||||
|
||||
func GetAuditLogService() *AuditLogService {
|
||||
return auditLogService
|
||||
@@ -1427,7 +1234,7 @@ type GetAuditLogsResponse struct {
|
||||
}
|
||||
```
|
||||
|
||||
#### models.go
|
||||
#### model.go
|
||||
|
||||
```go
|
||||
package audit_logs
|
||||
@@ -1490,63 +1297,7 @@ func (r *AuditLogRepository) GetGlobal(limit, offset int, beforeDate *time.Time)
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetByUser(
|
||||
userID uuid.UUID,
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().
|
||||
Where("user_id = ?", userID).
|
||||
Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetByProject(
|
||||
projectID uuid.UUID,
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().
|
||||
Where("project_id = ?", projectID).
|
||||
Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
|
||||
var count int64
|
||||
query := storage.GetDb().Model(&AuditLog{})
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
// GetByUser, GetByProject, CountGlobal follow the same pattern...
|
||||
```
|
||||
|
||||
#### service.go
|
||||
@@ -1570,11 +1321,7 @@ type AuditLogService struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *AuditLogService) WriteAuditLog(
|
||||
message string,
|
||||
userID *uuid.UUID,
|
||||
projectID *uuid.UUID,
|
||||
) {
|
||||
func (s *AuditLogService) WriteAuditLog(message string, userID *uuid.UUID, projectID *uuid.UUID) {
|
||||
auditLog := &AuditLog{
|
||||
UserID: userID,
|
||||
ProjectID: projectID,
|
||||
@@ -1582,17 +1329,11 @@ func (s *AuditLogService) WriteAuditLog(
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
err := s.auditLogRepository.Create(auditLog)
|
||||
if err != nil {
|
||||
if err := s.auditLogRepository.Create(auditLog); err != nil {
|
||||
s.logger.Error("failed to create audit log", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuditLogService) CreateAuditLog(auditLog *AuditLog) error {
|
||||
return s.auditLogRepository.Create(auditLog)
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetGlobalAuditLogs(
|
||||
user *user_models.User,
|
||||
request *GetAuditLogsRequest,
|
||||
@@ -1626,59 +1367,7 @@ func (s *AuditLogService) GetGlobalAuditLogs(
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetUserAuditLogs(
|
||||
targetUserID uuid.UUID,
|
||||
user *user_models.User,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
// Users can view their own logs, ADMIN can view any user's logs
|
||||
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
|
||||
return nil, errors.New("insufficient permissions to view user audit logs")
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetByUser(targetUserID, limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: int64(len(auditLogs)),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetProjectAuditLogs(
|
||||
projectID uuid.UUID,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetByProject(projectID, limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: int64(len(auditLogs)),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
// GetUserAuditLogs, GetProjectAuditLogs follow the same pattern...
|
||||
```
|
||||
|
||||
#### service_test.go
|
||||
@@ -1686,34 +1375,16 @@ func (s *AuditLogService) GetProjectAuditLogs(
|
||||
```go
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func Test_AuditLogs_ProjectSpecificLogs(t *testing.T) {
|
||||
service := GetAuditLogService()
|
||||
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
project1ID, project2ID := uuid.New(), uuid.New()
|
||||
project1ID := uuid.New()
|
||||
|
||||
// Create test logs for projects
|
||||
createAuditLog(service, "Test project1 log first", &user1.UserID, &project1ID)
|
||||
createAuditLog(service, "Test project1 log second", &user2.UserID, &project1ID)
|
||||
createAuditLog(service, "Test project2 log first", &user1.UserID, &project2ID)
|
||||
createAuditLog(service, "Test project2 log second", &user2.UserID, &project2ID)
|
||||
createAuditLog(service, "Test no project log", &user1.UserID, nil)
|
||||
createAuditLog(service, "Test project1 log second", &user1.UserID, &project1ID)
|
||||
|
||||
request := &GetAuditLogsRequest{Limit: 10, Offset: 0}
|
||||
|
||||
// Test project 1 logs
|
||||
project1Response, err := service.GetProjectAuditLogs(project1ID, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(project1Response.AuditLogs))
|
||||
@@ -1721,34 +1392,6 @@ func Test_AuditLogs_ProjectSpecificLogs(t *testing.T) {
|
||||
messages := extractMessages(project1Response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test project1 log first")
|
||||
assert.Contains(t, messages, "Test project1 log second")
|
||||
for _, log := range project1Response.AuditLogs {
|
||||
assert.Equal(t, &project1ID, log.ProjectID)
|
||||
}
|
||||
|
||||
// Test project 2 logs
|
||||
project2Response, err := service.GetProjectAuditLogs(project2ID, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(project2Response.AuditLogs))
|
||||
|
||||
messages2 := extractMessages(project2Response.AuditLogs)
|
||||
assert.Contains(t, messages2, "Test project2 log first")
|
||||
assert.Contains(t, messages2, "Test project2 log second")
|
||||
|
||||
// Test pagination
|
||||
limitedResponse, err := service.GetProjectAuditLogs(project1ID,
|
||||
&GetAuditLogsRequest{Limit: 1, Offset: 0})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(limitedResponse.AuditLogs))
|
||||
assert.Equal(t, 1, limitedResponse.Limit)
|
||||
|
||||
// Test beforeDate filter
|
||||
beforeTime := time.Now().UTC().Add(-1 * time.Minute)
|
||||
filteredResponse, err := service.GetProjectAuditLogs(project1ID,
|
||||
&GetAuditLogsRequest{Limit: 10, BeforeDate: &beforeTime})
|
||||
assert.NoError(t, err)
|
||||
for _, log := range filteredResponse.AuditLogs {
|
||||
assert.True(t, log.CreatedAt.Before(beforeTime))
|
||||
}
|
||||
}
|
||||
|
||||
func createAuditLog(service *AuditLogService, message string, userID, projectID *uuid.UUID) {
|
||||
@@ -1762,16 +1405,6 @@ func extractMessages(logs []*AuditLog) []string {
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt time.Time) {
|
||||
log := &AuditLog{
|
||||
ID: uuid.New(),
|
||||
UserID: userID,
|
||||
Message: message,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
db.Create(log)
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
12
Dockerfile
12
Dockerfile
@@ -316,7 +316,9 @@ window.__RUNTIME_CONFIG__ = {
|
||||
GOOGLE_CLIENT_ID: '\${GOOGLE_CLIENT_ID:-}',
|
||||
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED',
|
||||
CLOUDFLARE_TURNSTILE_SITE_KEY: '\${CLOUDFLARE_TURNSTILE_SITE_KEY:-}',
|
||||
CONTAINER_ARCH: '\${CONTAINER_ARCH:-unknown}'
|
||||
CONTAINER_ARCH: '\${CONTAINER_ARCH:-unknown}',
|
||||
CLOUD_PRICE_PER_GB: '\${CLOUD_PRICE_PER_GB:-}',
|
||||
CLOUD_PADDLE_CLIENT_TOKEN: '\${CLOUD_PADDLE_CLIENT_TOKEN:-}'
|
||||
};
|
||||
JSEOF
|
||||
|
||||
@@ -329,6 +331,14 @@ if [ -n "\${ANALYTICS_SCRIPT:-}" ]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
# Inject Paddle script if client token is provided (only if not already injected)
|
||||
if [ -n "\${CLOUD_PADDLE_CLIENT_TOKEN:-}" ]; then
|
||||
if ! grep -q "cdn.paddle.com" /app/ui/build/index.html 2>/dev/null; then
|
||||
echo "Injecting Paddle script..."
|
||||
sed -i "s#</head># <script src=\"https://cdn.paddle.com/paddle/v2/paddle.js\"></script>\n </head>#" /app/ui/build/index.html
|
||||
fi
|
||||
fi
|
||||
|
||||
# Inject static HTML into root div for cloud mode (payment system requires visible legal links)
|
||||
if [ "\${IS_CLOUD:-false}" = "true" ]; then
|
||||
if ! grep -q "cloud-static-content" /app/ui/build/index.html 2>/dev/null; then
|
||||
|
||||
22
NOTICE.md
Normal file
22
NOTICE.md
Normal file
@@ -0,0 +1,22 @@
|
||||
Copyright © 2025–2026 Rostislav Dugin and contributors.
|
||||
|
||||
“Databasus” is a trademark of Rostislav Dugin.
|
||||
|
||||
The source code in this repository is licensed under the Apache License, Version 2.0.
|
||||
That license applies to the code only and does not grant any right to use the
|
||||
Databasus name, logo, or branding, except for reasonable and customary referential
|
||||
use in describing the origin of the software and reproducing the content of this NOTICE.
|
||||
|
||||
Permitted referential use includes truthful use of the name “Databasus” to identify
|
||||
the original Databasus project in software catalogs, deployment templates, hosting
|
||||
panels, package indexes, compatibility pages, integrations, tutorials, reviews, and
|
||||
similar informational materials, including phrases such as “Databasus”,
|
||||
“Deploy Databasus”, “Databasus on Coolify”, and “Compatible with Databasus”.
|
||||
|
||||
You may not use “Databasus” as the name or primary branding of a competing product,
|
||||
service, fork, distribution, or hosted offering, or in any manner likely to cause
|
||||
confusion as to source, affiliation, sponsorship, or endorsement.
|
||||
|
||||
Nothing in this repository transfers, waives, limits, or estops any rights in the
|
||||
Databasus mark. All trademark rights are reserved except for the limited referential
|
||||
use stated above.
|
||||
16
README.md
16
README.md
@@ -1,8 +1,8 @@
|
||||
<div align="center">
|
||||
<img src="assets/logo.svg" alt="Databasus Logo" width="250"/>
|
||||
|
||||
<h3>Backup tool for PostgreSQL, MySQL and MongoDB</h3>
|
||||
<p>Databasus is a free, open source and self-hosted tool to backup databases (with focus on PostgreSQL). Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.)</p>
|
||||
<h3>PostgreSQL backup tool (with MySQL\MariaDB and MongoDB support)</h3>
|
||||
<p>Databasus is a free, open source and self-hosted tool to backup databases (with primary focus on PostgreSQL). Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.)</p>
|
||||
|
||||
<!-- Badges -->
|
||||
[](https://www.postgresql.org/)
|
||||
@@ -259,7 +259,9 @@ Contributions are welcome! Read the <a href="https://databasus.com/contribute">c
|
||||
|
||||
Also you can join our large community of developers, DBAs and DevOps engineers on Telegram [@databasus_community](https://t.me/databasus_community).
|
||||
|
||||
## AI disclaimer
|
||||
## FAQ
|
||||
|
||||
### AI disclaimer
|
||||
|
||||
There have been questions about AI usage in project development in issues and discussions. As the project focuses on security, reliability and production usage, it's important to explain how AI is used in the development process.
|
||||
|
||||
@@ -295,3 +297,11 @@ Moreover, it's important to note that we do not differentiate between bad human
|
||||
Even if code is written manually by a human, it's not guaranteed to be merged. Vibe code is not allowed at all and all such PRs are rejected by default (see [contributing guide](https://databasus.com/contribute)).
|
||||
|
||||
We also draw attention to fast issue resolution and security [vulnerability reporting](https://github.com/databasus/databasus?tab=security-ov-file#readme).
|
||||
|
||||
### You have a cloud version — are you truly open source?
|
||||
|
||||
Yes. Every feature available in Databasus Cloud is equally available in the self-hosted version with no restrictions, no feature gates and no usage limits. The entire codebase is Apache 2.0 licensed and always will be.
|
||||
|
||||
Databasus is not "open core." We do not withhold features behind a paid tier and then call the limited remainder "open source," as projects like GitLab or Sentry do. We believe open source means the complete product is open, not just a marketing label on a stripped-down edition.
|
||||
|
||||
Databasus Cloud runs the exact same code as the self-hosted version. The only difference is that we take care of infrastructure, high availability, backups, reservations, monitoring and updates for you — so you don't have to. Revenue from Cloud funds full-time development of the project. Most large open-source projects rely on corporate backing or sponsorship to survive. We chose a different path: Databasus sustains itself so it can grow and improve independently, without being tied to any enterprise or sponsor.
|
||||
|
||||
@@ -27,6 +27,13 @@ VALKEY_PORT=6379
|
||||
VALKEY_USERNAME=
|
||||
VALKEY_PASSWORD=
|
||||
VALKEY_IS_SSL=false
|
||||
# billing
|
||||
PRICE_PER_GB_CENTS=
|
||||
IS_PADDLE_SANDBOX=true
|
||||
PADDLE_API_KEY=
|
||||
PADDLE_WEBHOOK_SECRET=
|
||||
PADDLE_PRICE_ID=
|
||||
PADDLE_CLIENT_TOKEN=
|
||||
# testing
|
||||
# to get Google Drive env variables: add storage in UI and copy data from added storage here
|
||||
TEST_GOOGLE_DRIVE_CLIENT_ID=
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime/debug"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -25,6 +26,8 @@ import (
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/billing"
|
||||
billing_paddle "databasus-backend/internal/features/billing/paddle"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/disk"
|
||||
"databasus-backend/internal/features/encryption/secrets"
|
||||
@@ -105,7 +108,9 @@ func main() {
|
||||
go generateSwaggerDocs(log)
|
||||
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
ginApp := gin.Default()
|
||||
ginApp := gin.New()
|
||||
ginApp.Use(gin.Logger())
|
||||
ginApp.Use(ginRecoveryWithLogger(log))
|
||||
|
||||
// Add GZIP compression middleware
|
||||
ginApp.Use(gzip.Gzip(
|
||||
@@ -217,6 +222,10 @@ func setUpRoutes(r *gin.Engine) {
|
||||
backups_controllers.GetPostgresWalBackupController().RegisterRoutes(v1)
|
||||
databases.GetDatabaseController().RegisterPublicRoutes(v1)
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
billing_paddle.GetPaddleBillingController().RegisterPublicRoutes(v1)
|
||||
}
|
||||
|
||||
// Setup auth middleware
|
||||
userService := users_services.GetUserService()
|
||||
authMiddleware := users_middleware.AuthMiddleware(userService)
|
||||
@@ -240,6 +249,7 @@ func setUpRoutes(r *gin.Engine) {
|
||||
audit_logs.GetAuditLogController().RegisterRoutes(protected)
|
||||
users_controllers.GetManagementController().RegisterRoutes(protected)
|
||||
users_controllers.GetSettingsController().RegisterRoutes(protected)
|
||||
billing.GetBillingController().RegisterRoutes(protected)
|
||||
}
|
||||
|
||||
func setUpDependencies() {
|
||||
@@ -252,6 +262,11 @@ func setUpDependencies() {
|
||||
storages.SetupDependencies()
|
||||
backups_config.SetupDependencies()
|
||||
task_cancellation.SetupDependencies()
|
||||
billing.SetupDependencies()
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
billing_paddle.SetupDependencies()
|
||||
}
|
||||
}
|
||||
|
||||
func runBackgroundTasks(log *slog.Logger) {
|
||||
@@ -308,6 +323,12 @@ func runBackgroundTasks(log *slog.Logger) {
|
||||
go runWithPanicLogging(log, "restore nodes registry background service", func() {
|
||||
restoring.GetRestoreNodesRegistry().Run(ctx)
|
||||
})
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
go runWithPanicLogging(log, "billing background service", func() {
|
||||
billing.GetBillingService().Run(ctx, *log)
|
||||
})
|
||||
}
|
||||
} else {
|
||||
log.Info("Skipping primary node tasks as not primary node")
|
||||
}
|
||||
@@ -330,7 +351,7 @@ func runBackgroundTasks(log *slog.Logger) {
|
||||
func runWithPanicLogging(log *slog.Logger, serviceName string, fn func()) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error("Panic in "+serviceName, "error", r)
|
||||
log.Error("Panic in "+serviceName, "error", r, "stacktrace", string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
fn()
|
||||
@@ -410,6 +431,25 @@ func enableCors(ginApp *gin.Engine) {
|
||||
}
|
||||
}
|
||||
|
||||
func ginRecoveryWithLogger(log *slog.Logger) gin.HandlerFunc {
|
||||
return func(ctx *gin.Context) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error("Panic recovered in HTTP handler",
|
||||
"error", r,
|
||||
"stacktrace", string(debug.Stack()),
|
||||
"method", ctx.Request.Method,
|
||||
"path", ctx.Request.URL.Path,
|
||||
)
|
||||
|
||||
ctx.AbortWithStatus(http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func mountFrontend(ginApp *gin.Engine) {
|
||||
staticDir := "./ui/build"
|
||||
ginApp.NoRoute(func(c *gin.Context) {
|
||||
|
||||
@@ -5,6 +5,7 @@ go 1.26.1
|
||||
require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
|
||||
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3
|
||||
github.com/PaddleHQ/paddle-go-sdk v1.0.0
|
||||
github.com/gin-contrib/cors v1.7.5
|
||||
github.com/gin-contrib/gzip v1.2.3
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
@@ -100,6 +101,8 @@ require (
|
||||
github.com/emersion/go-message v0.18.2 // indirect
|
||||
github.com/emersion/go-vcard v0.0.0-20241024213814-c9703dde27ff // indirect
|
||||
github.com/flynn/noise v1.1.0 // indirect
|
||||
github.com/ggicci/httpin v0.19.0 // indirect
|
||||
github.com/ggicci/owl v0.8.2 // indirect
|
||||
github.com/go-chi/chi/v5 v5.2.3 // indirect
|
||||
github.com/go-darwin/apfs v0.0.0-20211011131704-f84b94dbf348 // indirect
|
||||
github.com/go-git/go-billy/v5 v5.6.2 // indirect
|
||||
|
||||
@@ -77,6 +77,8 @@ github.com/Max-Sum/base32768 v0.0.0-20230304063302-18e6ce5945fd/go.mod h1:C8yoIf
|
||||
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/PaddleHQ/paddle-go-sdk v1.0.0 h1:+EXitsPFbRcc0CpQE/MIeudxiVOR8pFe/aOWTEUHDKU=
|
||||
github.com/PaddleHQ/paddle-go-sdk v1.0.0/go.mod h1:kbBBzf0BHEj38QvhtoELqlGip3alKgA/I+vl7RQzB58=
|
||||
github.com/ProtonMail/bcrypt v0.0.0-20210511135022-227b4adcab57/go.mod h1:HecWFHognK8GfRDGnFQbW/LiV7A3MX3gZVs45vk5h8I=
|
||||
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs69zUkSzubzjBbL+cmOXgnmt9Fyd9ug=
|
||||
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo=
|
||||
@@ -248,6 +250,10 @@ github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9t
|
||||
github.com/geoffgarside/ber v1.1.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc=
|
||||
github.com/geoffgarside/ber v1.2.0 h1:/loowoRcs/MWLYmGX9QtIAbA+V/FrnVLsMMPhwiRm64=
|
||||
github.com/geoffgarside/ber v1.2.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc=
|
||||
github.com/ggicci/httpin v0.19.0 h1:p0B3SWLVgg770VirYiHB14M5wdRx3zR8mCTzM/TkTQ8=
|
||||
github.com/ggicci/httpin v0.19.0/go.mod h1:hzsQHcbqLabmGOycf7WNw6AAzcVbsMeoOp46bWAbIWc=
|
||||
github.com/ggicci/owl v0.8.2 h1:og+lhqpzSMPDdEB+NJfzoAJARP7qCG3f8uUC3xvGukA=
|
||||
github.com/ggicci/owl v0.8.2/go.mod h1:PHRD57u41vFN5UtFz2SF79yTVoM3HlWpjMiE+ZU2dj4=
|
||||
github.com/gin-contrib/cors v1.7.5 h1:cXC9SmofOrRg0w9PigwGlHG3ztswH6bqq4vJVXnvYMk=
|
||||
github.com/gin-contrib/cors v1.7.5/go.mod h1:4q3yi7xBEDDWKapjT2o1V7mScKDDr8k+jZ0fSquGoy0=
|
||||
github.com/gin-contrib/gzip v1.2.3 h1:dAhT722RuEG330ce2agAs75z7yB+NKvX/ZM1r8w0u2U=
|
||||
@@ -454,6 +460,8 @@ github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1
|
||||
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
|
||||
github.com/jtolio/noiseconn v0.0.0-20231127013910-f6d9ecbf1de7 h1:JcltaO1HXM5S2KYOYcKgAV7slU0xPy1OcvrVgn98sRQ=
|
||||
github.com/jtolio/noiseconn v0.0.0-20231127013910-f6d9ecbf1de7/go.mod h1:MEkhEPFwP3yudWO0lj6vfYpLIB+3eIcuIW+e0AZzUQk=
|
||||
github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo=
|
||||
github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA=
|
||||
github.com/jzelinskie/whirlpool v0.0.0-20201016144138-0675e54bb004 h1:G+9t9cEtnC9jFiTxyptEKuNIAbiN5ZCQzX2a74lj3xg=
|
||||
github.com/jzelinskie/whirlpool v0.0.0-20201016144138-0675e54bb004/go.mod h1:KmHnJWQrgEvbuy0vcvj00gtMqbvNn1L+3YUZLK/B92c=
|
||||
github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU=
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ilyakaznacheev/cleanenv"
|
||||
"github.com/joho/godotenv"
|
||||
@@ -53,6 +54,20 @@ type EnvVariables struct {
|
||||
TempFolder string
|
||||
SecretKeyPath string
|
||||
|
||||
// Billing (always tax-exclusive)
|
||||
PricePerGBCents int64 `env:"PRICE_PER_GB_CENTS"`
|
||||
MinStorageGB int
|
||||
MaxStorageGB int
|
||||
TrialDuration time.Duration
|
||||
TrialStorageGB int
|
||||
GracePeriod time.Duration
|
||||
// Paddle billing
|
||||
IsPaddleSandbox bool `env:"IS_PADDLE_SANDBOX"`
|
||||
PaddleApiKey string `env:"PADDLE_API_KEY"`
|
||||
PaddleWebhookSecret string `env:"PADDLE_WEBHOOK_SECRET"`
|
||||
PaddlePriceID string `env:"PADDLE_PRICE_ID"`
|
||||
PaddleClientToken string `env:"PADDLE_CLIENT_TOKEN"`
|
||||
|
||||
TestGoogleDriveClientID string `env:"TEST_GOOGLE_DRIVE_CLIENT_ID"`
|
||||
TestGoogleDriveClientSecret string `env:"TEST_GOOGLE_DRIVE_CLIENT_SECRET"`
|
||||
TestGoogleDriveTokenJSON string `env:"TEST_GOOGLE_DRIVE_TOKEN_JSON"`
|
||||
@@ -132,9 +147,9 @@ var (
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func GetEnv() EnvVariables {
|
||||
func GetEnv() *EnvVariables {
|
||||
once.Do(loadEnvVariables)
|
||||
return env
|
||||
return &env
|
||||
}
|
||||
|
||||
func loadEnvVariables() {
|
||||
@@ -363,5 +378,39 @@ func loadEnvVariables() {
|
||||
|
||||
}
|
||||
|
||||
// Billing
|
||||
if env.IsCloud {
|
||||
if env.PricePerGBCents == 0 {
|
||||
log.Error("PRICE_PER_GB_CENTS is empty or zero")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.PaddleApiKey == "" {
|
||||
log.Error("PADDLE_API_KEY is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.PaddleWebhookSecret == "" {
|
||||
log.Error("PADDLE_WEBHOOK_SECRET is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.PaddlePriceID == "" {
|
||||
log.Error("PADDLE_PRICE_ID is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.PaddleClientToken == "" {
|
||||
log.Error("PADDLE_CLIENT_TOKEN is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
env.MinStorageGB = 20
|
||||
env.MaxStorageGB = 10_000
|
||||
env.TrialDuration = 24 * time.Hour
|
||||
env.TrialStorageGB = 20
|
||||
env.GracePeriod = 30 * 24 * time.Hour
|
||||
|
||||
log.Info("Environment variables loaded successfully!")
|
||||
}
|
||||
|
||||
@@ -171,26 +171,6 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
backup.BackupSizeMb = completedMBs
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Check size limit (0 = unlimited)
|
||||
if backupConfig.MaxBackupSizeMB > 0 &&
|
||||
completedMBs > float64(backupConfig.MaxBackupSizeMB) {
|
||||
errMsg := fmt.Sprintf(
|
||||
"backup size (%.2f MB) exceeded maximum allowed size (%d MB)",
|
||||
completedMBs,
|
||||
backupConfig.MaxBackupSizeMB,
|
||||
)
|
||||
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.IsSkipRetry = true
|
||||
backup.FailMessage = &errMsg
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup with size exceeded error", "error", err)
|
||||
}
|
||||
cancel() // Cancel the backup context
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to update backup progress", "error", err)
|
||||
}
|
||||
|
||||
@@ -153,121 +153,3 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
assert.Equal(t, notifier.ID, capturedNotifier.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_BackupSizeLimits(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
t.Run("UnlimitedSize_MaxBackupSizeMBIsZero_BackupCompletes", func(t *testing.T) {
|
||||
// Enable backups with unlimited size (0)
|
||||
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backupConfig.MaxBackupSizeMB = 0 // unlimited
|
||||
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.createBackupUseCase = &CreateLargeBackupUsecase{}
|
||||
|
||||
// Create a backup record
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, false)
|
||||
|
||||
// Verify backup completed successfully even with large size
|
||||
updatedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
|
||||
assert.Equal(t, float64(10000), updatedBackup.BackupSizeMb)
|
||||
assert.Nil(t, updatedBackup.FailMessage)
|
||||
})
|
||||
|
||||
t.Run("SizeExceeded_BackupFailedWithIsSkipRetry", func(t *testing.T) {
|
||||
// Enable backups with 5 MB limit
|
||||
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backupConfig.MaxBackupSizeMB = 5
|
||||
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.createBackupUseCase = &CreateProgressiveBackupUsecase{}
|
||||
|
||||
// Create a backup record
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, false)
|
||||
|
||||
// Verify backup was marked as failed with IsSkipRetry=true
|
||||
updatedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusFailed, updatedBackup.Status)
|
||||
assert.True(t, updatedBackup.IsSkipRetry)
|
||||
assert.NotNil(t, updatedBackup.FailMessage)
|
||||
assert.Contains(t, *updatedBackup.FailMessage, "exceeded maximum allowed size")
|
||||
assert.Contains(t, *updatedBackup.FailMessage, "10.00 MB")
|
||||
assert.Contains(t, *updatedBackup.FailMessage, "5 MB")
|
||||
assert.Greater(t, updatedBackup.BackupSizeMb, float64(5))
|
||||
})
|
||||
|
||||
t.Run("SizeWithinLimit_BackupCompletes", func(t *testing.T) {
|
||||
// Enable backups with 100 MB limit
|
||||
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backupConfig.MaxBackupSizeMB = 100
|
||||
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.createBackupUseCase = &CreateMediumBackupUsecase{}
|
||||
|
||||
// Create a backup record
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, false)
|
||||
|
||||
// Verify backup completed successfully
|
||||
updatedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
|
||||
assert.Equal(t, float64(50), updatedBackup.BackupSizeMb)
|
||||
assert.Nil(t, updatedBackup.FailMessage)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/storages"
|
||||
@@ -26,6 +27,7 @@ type BackupCleaner struct {
|
||||
backupRepository *backups_core.BackupRepository
|
||||
storageService *storages.StorageService
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
billingService BillingService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
logger *slog.Logger
|
||||
backupRemoveListeners []backups_core.BackupRemoveListener
|
||||
@@ -44,6 +46,10 @@ func (c *BackupCleaner) Run(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
retentionLog := c.logger.With("task_name", "clean_by_retention_policy")
|
||||
exceededLog := c.logger.With("task_name", "clean_exceeded_storage_backups")
|
||||
staleLog := c.logger.With("task_name", "clean_stale_basebackups")
|
||||
|
||||
ticker := time.NewTicker(cleanerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -52,16 +58,16 @@ func (c *BackupCleaner) Run(ctx context.Context) {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := c.cleanByRetentionPolicy(); err != nil {
|
||||
c.logger.Error("Failed to clean backups by retention policy", "error", err)
|
||||
if err := c.cleanByRetentionPolicy(retentionLog); err != nil {
|
||||
retentionLog.Error("failed to clean backups by retention policy", "error", err)
|
||||
}
|
||||
|
||||
if err := c.cleanExceededBackups(); err != nil {
|
||||
c.logger.Error("Failed to clean exceeded backups", "error", err)
|
||||
if err := c.cleanExceededStorageBackups(exceededLog); err != nil {
|
||||
exceededLog.Error("failed to clean exceeded backups", "error", err)
|
||||
}
|
||||
|
||||
if err := c.cleanStaleUploadedBasebackups(); err != nil {
|
||||
c.logger.Error("Failed to clean stale uploaded basebackups", "error", err)
|
||||
if err := c.cleanStaleUploadedBasebackups(staleLog); err != nil {
|
||||
staleLog.Error("failed to clean stale uploaded basebackups", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -104,7 +110,7 @@ func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemo
|
||||
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
|
||||
func (c *BackupCleaner) cleanStaleUploadedBasebackups(logger *slog.Logger) error {
|
||||
staleBackups, err := c.backupRepository.FindStaleUploadedBasebackups(
|
||||
time.Now().UTC().Add(-10 * time.Minute),
|
||||
)
|
||||
@@ -113,31 +119,30 @@ func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
|
||||
}
|
||||
|
||||
for _, backup := range staleBackups {
|
||||
backupLog := logger.With("database_id", backup.DatabaseID, "backup_id", backup.ID)
|
||||
|
||||
staleStorage, storageErr := c.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr != nil {
|
||||
c.logger.Error(
|
||||
"Failed to get storage for stale basebackup cleanup",
|
||||
"backupId", backup.ID,
|
||||
"storageId", backup.StorageID,
|
||||
backupLog.Error(
|
||||
"failed to get storage for stale basebackup cleanup",
|
||||
"storage_id", backup.StorageID,
|
||||
"error", storageErr,
|
||||
)
|
||||
} else {
|
||||
if err := staleStorage.DeleteFile(c.fieldEncryptor, backup.FileName); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete stale basebackup file",
|
||||
"backupId", backup.ID,
|
||||
"fileName", backup.FileName,
|
||||
"error", err,
|
||||
backupLog.Error(
|
||||
fmt.Sprintf("failed to delete stale basebackup file: %s", backup.FileName),
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
metadataFileName := backup.FileName + ".metadata"
|
||||
if err := staleStorage.DeleteFile(c.fieldEncryptor, metadataFileName); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete stale basebackup metadata file",
|
||||
"backupId", backup.ID,
|
||||
"fileName", metadataFileName,
|
||||
"error", err,
|
||||
backupLog.Error(
|
||||
fmt.Sprintf("failed to delete stale basebackup metadata file: %s", metadataFileName),
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -147,77 +152,67 @@ func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
|
||||
backup.FailMessage = &failMsg
|
||||
|
||||
if err := c.backupRepository.Save(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to mark stale uploaded basebackup as failed",
|
||||
"backupId", backup.ID,
|
||||
"error", err,
|
||||
)
|
||||
backupLog.Error("failed to mark stale uploaded basebackup as failed", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Marked stale uploaded basebackup as failed and cleaned storage",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", backup.DatabaseID,
|
||||
)
|
||||
backupLog.Info("marked stale uploaded basebackup as failed and cleaned storage")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByRetentionPolicy() error {
|
||||
func (c *BackupCleaner) cleanByRetentionPolicy(logger *slog.Logger) error {
|
||||
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
dbLog := logger.With("database_id", backupConfig.DatabaseID, "policy", backupConfig.RetentionPolicyType)
|
||||
|
||||
var cleanErr error
|
||||
|
||||
switch backupConfig.RetentionPolicyType {
|
||||
case backups_config.RetentionPolicyTypeCount:
|
||||
cleanErr = c.cleanByCount(backupConfig)
|
||||
cleanErr = c.cleanByCount(dbLog, backupConfig)
|
||||
case backups_config.RetentionPolicyTypeGFS:
|
||||
cleanErr = c.cleanByGFS(backupConfig)
|
||||
cleanErr = c.cleanByGFS(dbLog, backupConfig)
|
||||
default:
|
||||
cleanErr = c.cleanByTimePeriod(backupConfig)
|
||||
cleanErr = c.cleanByTimePeriod(dbLog, backupConfig)
|
||||
}
|
||||
|
||||
if cleanErr != nil {
|
||||
c.logger.Error(
|
||||
"Failed to clean backups by retention policy",
|
||||
"databaseId", backupConfig.DatabaseID,
|
||||
"policy", backupConfig.RetentionPolicyType,
|
||||
"error", cleanErr,
|
||||
)
|
||||
dbLog.Error("failed to clean backups by retention policy", "error", cleanErr)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanExceededBackups() error {
|
||||
func (c *BackupCleaner) cleanExceededStorageBackups(logger *slog.Logger) error {
|
||||
if !config.GetEnv().IsCloud {
|
||||
return nil
|
||||
}
|
||||
|
||||
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
if backupConfig.MaxBackupsTotalSizeMB <= 0 {
|
||||
dbLog := logger.With("database_id", backupConfig.DatabaseID)
|
||||
|
||||
subscription, subErr := c.billingService.GetSubscription(dbLog, backupConfig.DatabaseID)
|
||||
if subErr != nil {
|
||||
dbLog.Error("failed to get subscription for exceeded backups check", "error", subErr)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.cleanExceededBackupsForDatabase(
|
||||
backupConfig.DatabaseID,
|
||||
backupConfig.MaxBackupsTotalSizeMB,
|
||||
); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to clean exceeded backups for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
storageLimitMB := int64(subscription.GetBackupsStorageGB()) * 1024
|
||||
|
||||
if err := c.cleanExceededBackupsForDatabase(dbLog, backupConfig.DatabaseID, storageLimitMB); err != nil {
|
||||
dbLog.Error("failed to clean exceeded backups for database", "error", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -225,7 +220,7 @@ func (c *BackupCleaner) cleanExceededBackups() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByTimePeriod(backupConfig *backups_config.BackupConfig) error {
|
||||
func (c *BackupCleaner) cleanByTimePeriod(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
|
||||
if backupConfig.RetentionTimePeriod == "" {
|
||||
return nil
|
||||
}
|
||||
@@ -255,21 +250,17 @@ func (c *BackupCleaner) cleanByTimePeriod(backupConfig *backups_config.BackupCon
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
|
||||
logger.Error("failed to delete old backup", "backup_id", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted old backup",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", backupConfig.DatabaseID,
|
||||
)
|
||||
logger.Info("deleted old backup", "backup_id", backup.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByCount(backupConfig *backups_config.BackupConfig) error {
|
||||
func (c *BackupCleaner) cleanByCount(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
|
||||
if backupConfig.RetentionCount <= 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -298,28 +289,20 @@ func (c *BackupCleaner) cleanByCount(backupConfig *backups_config.BackupConfig)
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete backup by count policy",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
logger.Error("failed to delete backup by count policy", "backup_id", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted backup by count policy",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", backupConfig.DatabaseID,
|
||||
"retentionCount", backupConfig.RetentionCount,
|
||||
logger.Info(
|
||||
fmt.Sprintf("deleted backup by count policy: retention count is %d", backupConfig.RetentionCount),
|
||||
"backup_id", backup.ID,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByGFS(backupConfig *backups_config.BackupConfig) error {
|
||||
func (c *BackupCleaner) cleanByGFS(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
|
||||
if backupConfig.RetentionGfsHours <= 0 && backupConfig.RetentionGfsDays <= 0 &&
|
||||
backupConfig.RetentionGfsWeeks <= 0 && backupConfig.RetentionGfsMonths <= 0 &&
|
||||
backupConfig.RetentionGfsYears <= 0 {
|
||||
@@ -357,29 +340,20 @@ func (c *BackupCleaner) cleanByGFS(backupConfig *backups_config.BackupConfig) er
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete backup by GFS policy",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
logger.Error("failed to delete backup by GFS policy", "backup_id", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted backup by GFS policy",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", backupConfig.DatabaseID,
|
||||
)
|
||||
logger.Info("deleted backup by GFS policy", "backup_id", backup.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanExceededBackupsForDatabase(
|
||||
logger *slog.Logger,
|
||||
databaseID uuid.UUID,
|
||||
limitperDbMB int64,
|
||||
limitPerDbMB int64,
|
||||
) error {
|
||||
for {
|
||||
backupsTotalSizeMB, err := c.backupRepository.GetTotalSizeByDatabase(databaseID)
|
||||
@@ -387,7 +361,7 @@ func (c *BackupCleaner) cleanExceededBackupsForDatabase(
|
||||
return err
|
||||
}
|
||||
|
||||
if backupsTotalSizeMB <= float64(limitperDbMB) {
|
||||
if backupsTotalSizeMB <= float64(limitPerDbMB) {
|
||||
break
|
||||
}
|
||||
|
||||
@@ -400,59 +374,27 @@ func (c *BackupCleaner) cleanExceededBackupsForDatabase(
|
||||
}
|
||||
|
||||
if len(oldestBackups) == 0 {
|
||||
c.logger.Warn(
|
||||
"No backups to delete but still over limit",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"totalSizeMB",
|
||||
backupsTotalSizeMB,
|
||||
"limitMB",
|
||||
limitperDbMB,
|
||||
)
|
||||
logger.Warn(fmt.Sprintf(
|
||||
"no backups to delete but still over limit: total size is %.1f MB, limit is %d MB",
|
||||
backupsTotalSizeMB, limitPerDbMB,
|
||||
))
|
||||
break
|
||||
}
|
||||
|
||||
backup := oldestBackups[0]
|
||||
if isRecentBackup(backup) {
|
||||
c.logger.Warn(
|
||||
"Oldest backup is too recent to delete, stopping size cleanup",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"totalSizeMB",
|
||||
backupsTotalSizeMB,
|
||||
"limitMB",
|
||||
limitperDbMB,
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete exceeded backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
logger.Error("failed to delete exceeded backup", "backup_id", backup.ID, "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted exceeded backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"backupSizeMB",
|
||||
backup.BackupSizeMb,
|
||||
"totalSizeMB",
|
||||
backupsTotalSizeMB,
|
||||
"limitMB",
|
||||
limitperDbMB,
|
||||
logger.Info(
|
||||
fmt.Sprintf("deleted exceeded backup: backup size is %.1f MB, total size is %.1f MB, limit is %d MB",
|
||||
backup.BackupSizeMb, backupsTotalSizeMB, limitPerDbMB),
|
||||
"backup_id", backup.ID,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -425,7 +425,7 @@ func Test_CleanByGFS_KeepsCorrectBackupsPerSlot(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -502,7 +502,7 @@ func Test_CleanByGFS_WithWeeklyAndMonthlySlots_KeepsWiderSpread(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -576,7 +576,7 @@ func Test_CleanByGFS_WithHourlySlots_KeepsCorrectBackups(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -677,7 +677,7 @@ func Test_CleanByGFS_SkipsRecentBackup_WhenNotInKeepSet(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -759,7 +759,7 @@ func Test_CleanByGFS_With20DailyBackups_KeepsOnlyExpectedCount(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -844,7 +844,7 @@ func Test_CleanByGFS_WithMultipleBackupsPerDay_KeepsOnlyOnePerDailySlot(t *testi
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -929,7 +929,7 @@ func Test_CleanByGFS_With24HourlySlotsAnd23DailyBackups_DeletesExcessBackups(t *
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -999,7 +999,7 @@ func Test_CleanByGFS_WithDisabledHourlySlotsAnd23DailyBackups_DeletesExcessBacku
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -1069,7 +1069,7 @@ func Test_CleanByGFS_WithDailySlotsAndWeeklyBackups_DeletesExcessBackups(t *test
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -1152,7 +1152,7 @@ func Test_CleanByGFS_WithWeeklySlotsAndMonthlyBackups_DeletesExcessBackups(t *te
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
@@ -17,6 +20,7 @@ import (
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
|
||||
@@ -51,6 +55,7 @@ func Test_CleanOldBackups_DeletesBackupsOlderThanRetentionTimePeriod(t *testing.
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -89,7 +94,7 @@ func Test_CleanOldBackups_DeletesBackupsOlderThanRetentionTimePeriod(t *testing.
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -129,6 +134,7 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -145,7 +151,7 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -154,7 +160,8 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
|
||||
assert.Equal(t, oldBackup.ID, remainingBackups[0].ID)
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
|
||||
func Test_CleanExceededBackups_WhenUnderStorageLimit_NoBackupsDeleted(t *testing.T) {
|
||||
enableCloud(t)
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
@@ -178,14 +185,14 @@ func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 100,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -196,15 +203,18 @@ func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 16.67,
|
||||
BackupSizeMb: 100,
|
||||
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
|
||||
}
|
||||
cleaner := CreateTestBackupCleaner(mockBilling)
|
||||
err = cleaner.cleanExceededStorageBackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -212,7 +222,8 @@ func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
|
||||
assert.Equal(t, 3, len(remainingBackups))
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T) {
|
||||
func Test_CleanExceededBackups_WhenOverStorageLimit_DeletesOldestBackups(t *testing.T) {
|
||||
enableCloud(t)
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
@@ -236,18 +247,20 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 30,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 5 backups at 300 MB each = 1500 MB total, limit = 1 GB (1024 MB)
|
||||
// Expect 2 oldest deleted, 3 remain (900 MB < 1024 MB)
|
||||
now := time.Now().UTC()
|
||||
var backupIDs []uuid.UUID
|
||||
for i := 0; i < 5; i++ {
|
||||
@@ -256,7 +269,7 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10,
|
||||
BackupSizeMb: 300,
|
||||
CreatedAt: now.Add(-time.Duration(4-i) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
@@ -264,8 +277,11 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
|
||||
backupIDs = append(backupIDs, backup.ID)
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
|
||||
}
|
||||
cleaner := CreateTestBackupCleaner(mockBilling)
|
||||
err = cleaner.cleanExceededStorageBackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -284,6 +300,7 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
|
||||
enableCloud(t)
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
@@ -307,20 +324,21 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 50,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
// 3 completed at 500 MB each = 1500 MB, limit = 1 GB (1024 MB)
|
||||
completedBackups := make([]*backups_core.Backup, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
backup := &backups_core.Backup{
|
||||
@@ -328,7 +346,7 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 30,
|
||||
BackupSizeMb: 500,
|
||||
CreatedAt: now.Add(-time.Duration(3-i) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
@@ -347,8 +365,11 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
|
||||
err = backupRepository.Save(inProgressBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
|
||||
}
|
||||
cleaner := CreateTestBackupCleaner(mockBilling)
|
||||
err = cleaner.cleanExceededStorageBackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -365,7 +386,8 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
|
||||
assert.True(t, inProgressFound, "In-progress backup should not be deleted")
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
|
||||
func Test_CleanExceededBackups_WithZeroStorageLimit_RemovesAllBackups(t *testing.T) {
|
||||
enableCloud(t)
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
@@ -389,14 +411,14 @@ func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 0,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -408,19 +430,23 @@ func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 100,
|
||||
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
|
||||
CreatedAt: time.Now().UTC().Add(-time.Duration(i+2) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
// StorageGB=0 means no storage allowed — all backups should be removed
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{StorageGB: 0, Status: billing_models.StatusActive},
|
||||
}
|
||||
cleaner := CreateTestBackupCleaner(mockBilling)
|
||||
err = cleaner.cleanExceededStorageBackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 10, len(remainingBackups))
|
||||
assert.Equal(t, 0, len(remainingBackups))
|
||||
}
|
||||
|
||||
func Test_GetTotalSizeByDatabase_CalculatesCorrectly(t *testing.T) {
|
||||
@@ -522,6 +548,7 @@ func Test_CleanByCount_KeepsNewestNBackups_DeletesOlder(t *testing.T) {
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -545,7 +572,7 @@ func Test_CleanByCount_KeepsNewestNBackups_DeletesOlder(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -594,6 +621,7 @@ func Test_CleanByCount_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -612,7 +640,7 @@ func Test_CleanByCount_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -651,6 +679,7 @@ func Test_CleanByCount_DoesNotDeleteInProgressBackups(t *testing.T) {
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -682,7 +711,7 @@ func Test_CleanByCount_DoesNotDeleteInProgressBackups(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -776,6 +805,7 @@ func Test_CleanByTimePeriod_SkipsRecentBackup_EvenIfOlderThanRetention(t *testin
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -805,7 +835,7 @@ func Test_CleanByTimePeriod_SkipsRecentBackup_EvenIfOlderThanRetention(t *testin
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -847,6 +877,7 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -893,7 +924,7 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -914,7 +945,8 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
|
||||
assert.True(t, remainingIDs[newestBackup.ID], "Newest backup should be preserved")
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testing.T) {
|
||||
func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverStorageLimit(t *testing.T) {
|
||||
enableCloud(t)
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
@@ -937,18 +969,18 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
|
||||
|
||||
interval := createTestInterval()
|
||||
|
||||
// Total size limit is 10 MB. We have two backups of 8 MB each (16 MB total).
|
||||
// Total size limit = 1 GB (1024 MB). Two backups of 600 MB each (1200 MB total).
|
||||
// The oldest backup was created 30 minutes ago — within the grace period.
|
||||
// The cleaner must stop and leave both backups intact.
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 10,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -960,7 +992,7 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 8,
|
||||
BackupSizeMb: 600,
|
||||
CreatedAt: now.Add(-30 * time.Minute),
|
||||
}
|
||||
newerRecentBackup := &backups_core.Backup{
|
||||
@@ -968,7 +1000,7 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 8,
|
||||
BackupSizeMb: 600,
|
||||
CreatedAt: now.Add(-10 * time.Minute),
|
||||
}
|
||||
|
||||
@@ -977,8 +1009,11 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
|
||||
err = backupRepository.Save(newerRecentBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
|
||||
}
|
||||
cleaner := CreateTestBackupCleaner(mockBilling)
|
||||
err = cleaner.cleanExceededStorageBackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -991,6 +1026,82 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
|
||||
)
|
||||
}
|
||||
|
||||
func Test_CleanExceededStorageBackups_WhenNonCloud_SkipsCleanup(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 5 backups at 500 MB each = 2500 MB, would exceed 1 GB limit in cloud mode
|
||||
now := time.Now().UTC()
|
||||
for i := 0; i < 5; i++ {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 500,
|
||||
CreatedAt: now.Add(-time.Duration(i+2) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// IsCloud is false by default — cleaner should skip entirely
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
|
||||
}
|
||||
cleaner := CreateTestBackupCleaner(mockBilling)
|
||||
err = cleaner.cleanExceededStorageBackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 5, len(remainingBackups), "All backups must remain in non-cloud mode")
|
||||
}
|
||||
|
||||
type mockBillingService struct {
|
||||
subscription *billing_models.Subscription
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockBillingService) GetSubscription(
|
||||
logger *slog.Logger,
|
||||
databaseID uuid.UUID,
|
||||
) (*billing_models.Subscription, error) {
|
||||
return m.subscription, m.err
|
||||
}
|
||||
|
||||
// Mock listener for testing
|
||||
type mockBackupRemoveListener struct {
|
||||
onBeforeBackupRemove func(*backups_core.Backup) error
|
||||
@@ -1041,7 +1152,7 @@ func Test_CleanStaleUploadedBasebackups_MarksAsFailed(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(staleBackup.ID)
|
||||
@@ -1088,7 +1199,7 @@ func Test_CleanStaleUploadedBasebackups_SkipsRecentUploads(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(recentBackup.ID)
|
||||
@@ -1131,7 +1242,7 @@ func Test_CleanStaleUploadedBasebackups_SkipsActiveStreaming(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(activeBackup.ID)
|
||||
@@ -1179,7 +1290,7 @@ func Test_CleanStaleUploadedBasebackups_CleansStorageFiles(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(staleBackup.ID)
|
||||
@@ -1189,6 +1300,18 @@ func Test_CleanStaleUploadedBasebackups_CleansStorageFiles(t *testing.T) {
|
||||
assert.Contains(t, *updated.FailMessage, "finalization timed out")
|
||||
}
|
||||
|
||||
func enableCloud(t *testing.T) {
|
||||
t.Helper()
|
||||
config.GetEnv().IsCloud = true
|
||||
t.Cleanup(func() {
|
||||
config.GetEnv().IsCloud = false
|
||||
})
|
||||
}
|
||||
|
||||
func testLogger() *slog.Logger {
|
||||
return logger.GetLogger().With("task_name", "test")
|
||||
}
|
||||
|
||||
func createTestInterval() *intervals.Interval {
|
||||
timeOfDay := "04:00"
|
||||
interval := &intervals.Interval{
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/billing"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
@@ -28,6 +29,7 @@ var backupCleaner = &BackupCleaner{
|
||||
backupRepository,
|
||||
storages.GetStorageService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
billing.GetBillingService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
logger.GetLogger(),
|
||||
[]backups_core.BackupRemoveListener{},
|
||||
@@ -73,6 +75,7 @@ var backupsScheduler = &BackupsScheduler{
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
databases.GetDatabaseService(),
|
||||
billing.GetBillingService(),
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
make(map[uuid.UUID]BackupToNodeRelation),
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
)
|
||||
|
||||
type BillingService interface {
|
||||
GetSubscription(logger *slog.Logger, databaseID uuid.UUID) (*billing_models.Subscription, error)
|
||||
}
|
||||
@@ -29,6 +29,7 @@ type BackupsScheduler struct {
|
||||
taskCancelManager *task_cancellation.TaskCancelManager
|
||||
backupNodesRegistry *BackupNodesRegistry
|
||||
databaseService *databases.DatabaseService
|
||||
billingService BillingService
|
||||
|
||||
lastBackupTime time.Time
|
||||
logger *slog.Logger
|
||||
@@ -127,6 +128,34 @@ func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotif
|
||||
return
|
||||
}
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
subscription, subErr := s.billingService.GetSubscription(s.logger, database.ID)
|
||||
if subErr != nil || !subscription.CanCreateNewBackups() {
|
||||
failMessage := "subscription has expired, please renew"
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: *backupConfig.StorageID,
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
IsSkipRetry: true,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
backup.GenerateFilename(database.Name)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error(
|
||||
"failed to save failed backup for expired subscription",
|
||||
"database_id", database.ID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check for existing in-progress backups
|
||||
inProgressBackups, err := s.backupRepository.FindByDatabaseIdAndStatus(
|
||||
database.ID,
|
||||
@@ -346,6 +375,27 @@ func (s *BackupsScheduler) runPendingBackups() error {
|
||||
continue
|
||||
}
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
subscription, subErr := s.billingService.GetSubscription(s.logger, backupConfig.DatabaseID)
|
||||
if subErr != nil {
|
||||
s.logger.Warn(
|
||||
"failed to get subscription, skipping backup",
|
||||
"database_id", backupConfig.DatabaseID,
|
||||
"error", subErr,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
if !subscription.CanCreateNewBackups() {
|
||||
s.logger.Debug(
|
||||
"subscription is not active, skipping scheduled backup",
|
||||
"database_id", backupConfig.DatabaseID,
|
||||
"subscription_status", subscription.Status,
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
s.StartBackup(database, remainedBackupTryCount == 1)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
@@ -968,7 +969,7 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
// Start scheduler so it can handle task completions
|
||||
scheduler := CreateTestScheduler()
|
||||
scheduler := CreateTestScheduler(nil)
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
@@ -1065,7 +1066,7 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
// Start scheduler so it can handle task completions
|
||||
scheduler := CreateTestScheduler()
|
||||
scheduler := CreateTestScheduler(nil)
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
@@ -1332,7 +1333,7 @@ func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCa
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
// Create scheduler
|
||||
scheduler := CreateTestScheduler()
|
||||
scheduler := CreateTestScheduler(nil)
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
@@ -1458,3 +1459,313 @@ func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCa
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartBackup_WhenCloudAndSubscriptionExpired_CreatesFailedBackup(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{
|
||||
Status: billing_models.StatusExpired,
|
||||
},
|
||||
}
|
||||
scheduler := CreateTestScheduler(mockBilling)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
enableCloud(t)
|
||||
|
||||
scheduler.StartBackup(database, false)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
|
||||
newestBackup := backups[0]
|
||||
assert.Equal(t, backups_core.BackupStatusFailed, newestBackup.Status)
|
||||
assert.NotNil(t, newestBackup.FailMessage)
|
||||
assert.Equal(t, "subscription has expired, please renew", *newestBackup.FailMessage)
|
||||
assert.True(t, newestBackup.IsSkipRetry)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartBackup_WhenCloudAndSubscriptionActive_ProceedsNormally(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{
|
||||
Status: billing_models.StatusActive,
|
||||
StorageGB: 10,
|
||||
},
|
||||
}
|
||||
scheduler := CreateTestScheduler(mockBilling)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
enableCloud(t)
|
||||
|
||||
scheduler.StartBackup(database, false)
|
||||
|
||||
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
|
||||
newestBackup := backups[0]
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, newestBackup.Status)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RunPendingBackups_WhenCloudAndSubscriptionExpired_SilentlySkips(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{
|
||||
Status: billing_models.StatusExpired,
|
||||
},
|
||||
}
|
||||
scheduler := CreateTestScheduler(mockBilling)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
|
||||
backupConfig.RetentionTimePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
enableCloud(t)
|
||||
|
||||
scheduler.runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1, "No new backup should be created, scheduler silently skips expired subscriptions")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartBackup_WhenNotCloudAndSubscriptionExpired_ProceedsNormally(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{
|
||||
Status: billing_models.StatusExpired,
|
||||
},
|
||||
}
|
||||
scheduler := CreateTestScheduler(mockBilling)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
scheduler.StartBackup(database, false)
|
||||
|
||||
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backups[0].Status,
|
||||
"Billing check should not apply in non-cloud mode")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RunPendingBackups_WhenNotCloudAndSubscriptionExpired_ProceedsNormally(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{
|
||||
Status: billing_models.StatusExpired,
|
||||
},
|
||||
}
|
||||
scheduler := CreateTestScheduler(mockBilling)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
|
||||
backupConfig.RetentionTimePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
scheduler.runPendingBackups()
|
||||
|
||||
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 2, "Billing check should not apply in non-cloud mode, new backup should be created")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -35,58 +35,74 @@ func CreateTestRouter() *gin.Engine {
|
||||
return router
|
||||
}
|
||||
|
||||
func CreateTestBackupCleaner(billingService BillingService) *BackupCleaner {
|
||||
return &BackupCleaner{
|
||||
backupRepository,
|
||||
storages.GetStorageService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
billingService,
|
||||
encryption.GetFieldEncryptor(),
|
||||
logger.GetLogger(),
|
||||
[]backups_core.BackupRemoveListener{},
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestBackuperNode() *BackuperNode {
|
||||
return &BackuperNode{
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
createBackupUseCase: usecases.GetCreateBackupUsecase(),
|
||||
nodeID: uuid.New(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
logger.GetLogger(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
uuid.New(),
|
||||
time.Time{},
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestBackuperNodeWithUseCase(useCase backups_core.CreateBackupUsecase) *BackuperNode {
|
||||
return &BackuperNode{
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
createBackupUseCase: useCase,
|
||||
nodeID: uuid.New(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
logger.GetLogger(),
|
||||
useCase,
|
||||
uuid.New(),
|
||||
time.Time{},
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestScheduler() *BackupsScheduler {
|
||||
func CreateTestScheduler(billingService BillingService) *BackupsScheduler {
|
||||
return &BackupsScheduler{
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
taskCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
lastBackupTime: time.Now().UTC(),
|
||||
logger: logger.GetLogger(),
|
||||
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode: CreateTestBackuperNode(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
databases.GetDatabaseService(),
|
||||
billingService,
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
make(map[uuid.UUID]BackupToNodeRelation),
|
||||
CreateTestBackuperNode(),
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1263,7 +1263,7 @@ func Test_MakeBackup_VerifyBackupAndMetadataFilesExistInStorage(t *testing.T) {
|
||||
backuperCancel := backuping.StartBackuperNodeForTest(t, backuperNode)
|
||||
defer backuping.StopBackuperNodeForTest(t, backuperCancel, backuperNode)
|
||||
|
||||
scheduler := backuping.CreateTestScheduler()
|
||||
scheduler := backuping.CreateTestScheduler(nil)
|
||||
schedulerCancel := backuping.StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
@@ -1838,7 +1838,7 @@ func Test_DeleteBackup_RemovesBackupAndMetadataFilesFromDisk(t *testing.T) {
|
||||
backuperCancel := backuping.StartBackuperNodeForTest(t, backuperNode)
|
||||
defer backuping.StopBackuperNodeForTest(t, backuperCancel, backuperNode)
|
||||
|
||||
scheduler := backuping.CreateTestScheduler()
|
||||
scheduler := backuping.CreateTestScheduler(nil)
|
||||
schedulerCancel := backuping.StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ type BackupConfigController struct {
|
||||
|
||||
func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
router.POST("/backup-configs/save", c.SaveBackupConfig)
|
||||
router.GET("/backup-configs/database/:id/plan", c.GetDatabasePlan)
|
||||
router.GET("/backup-configs/database/:id", c.GetBackupConfigByDbID)
|
||||
router.GET("/backup-configs/storage/:id/is-using", c.IsStorageUsing)
|
||||
router.GET("/backup-configs/storage/:id/databases-count", c.CountDatabasesForStorage)
|
||||
@@ -93,39 +92,6 @@ func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
|
||||
ctx.JSON(http.StatusOK, backupConfig)
|
||||
}
|
||||
|
||||
// GetDatabasePlan
|
||||
// @Summary Get database plan by database ID
|
||||
// @Description Get the plan limits for a specific database (max backup size, max total size, max storage period)
|
||||
// @Tags backup-configs
|
||||
// @Produce json
|
||||
// @Param id path string true "Database ID"
|
||||
// @Success 200 {object} plans.DatabasePlan
|
||||
// @Failure 400 {object} map[string]string "Invalid database ID"
|
||||
// @Failure 401 {object} map[string]string "User not authenticated"
|
||||
// @Failure 404 {object} map[string]string "Database not found or access denied"
|
||||
// @Router /backup-configs/database/{id}/plan [get]
|
||||
func (c *BackupConfigController) GetDatabasePlan(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
|
||||
return
|
||||
}
|
||||
|
||||
plan, err := c.backupConfigService.GetDatabasePlan(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusNotFound, gin.H{"error": "database plan not found"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, plan)
|
||||
}
|
||||
|
||||
// IsStorageUsing
|
||||
// @Summary Check if storage is being used
|
||||
// @Description Check if a storage is currently being used by any backup configuration
|
||||
|
||||
@@ -17,14 +17,12 @@ import (
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/features/storages"
|
||||
local_storage "databasus-backend/internal/features/storages/models/local"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
"databasus-backend/internal/util/period"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
"databasus-backend/internal/util/tools"
|
||||
@@ -326,218 +324,13 @@ func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T)
|
||||
&response,
|
||||
)
|
||||
|
||||
var plan plans.DatabasePlan
|
||||
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&plan,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.False(t, response.IsBackupsEnabled)
|
||||
assert.Equal(t, plan.MaxStoragePeriod, response.RetentionTimePeriod)
|
||||
assert.Equal(t, plan.MaxBackupSizeMB, response.MaxBackupSizeMB)
|
||||
assert.Equal(t, plan.MaxBackupsTotalSizeMB, response.MaxBackupsTotalSizeMB)
|
||||
assert.True(t, response.IsRetryIfFailed)
|
||||
assert.Equal(t, 3, response.MaxFailedTriesCount)
|
||||
assert.NotNil(t, response.BackupInterval)
|
||||
}
|
||||
|
||||
func Test_GetDatabasePlan_ForNewDatabase_PlanAlwaysReturned(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
var response plans.DatabasePlan
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.NotNil(t, response.MaxBackupSizeMB)
|
||||
assert.NotNil(t, response.MaxBackupsTotalSizeMB)
|
||||
assert.NotEmpty(t, response.MaxStoragePeriod)
|
||||
}
|
||||
|
||||
func Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Get plan via API (triggers auto-creation)
|
||||
var plan plans.DatabasePlan
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&plan,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, plan.DatabaseID)
|
||||
|
||||
// Adjust plan limits directly in database to fixed restrictive values
|
||||
err := storage.GetDb().Model(&plans.DatabasePlan{}).
|
||||
Where("database_id = ?", database.ID).
|
||||
Updates(map[string]any{
|
||||
"max_backup_size_mb": 100,
|
||||
"max_backups_total_size_mb": 1000,
|
||||
"max_storage_period": period.PeriodMonth,
|
||||
}).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test 1: Try to save backup config with exceeded backup size limit
|
||||
timeOfDay := "04:00"
|
||||
backupConfigExceededSize := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 200, // Exceeds limit of 100
|
||||
MaxBackupsTotalSizeMB: 800,
|
||||
}
|
||||
|
||||
respExceededSize := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
backupConfigExceededSize,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
assert.Contains(t, string(respExceededSize.Body), "max backup size exceeds plan limit")
|
||||
|
||||
// Test 2: Try to save backup config with exceeded total size limit
|
||||
backupConfigExceededTotal := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 50,
|
||||
MaxBackupsTotalSizeMB: 2000, // Exceeds limit of 1000
|
||||
}
|
||||
|
||||
respExceededTotal := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
backupConfigExceededTotal,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
assert.Contains(t, string(respExceededTotal.Body), "max total backups size exceeds plan limit")
|
||||
|
||||
// Test 3: Try to save backup config with exceeded storage period limit
|
||||
backupConfigExceededPeriod := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodYear, // Exceeds limit of Month
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 80,
|
||||
MaxBackupsTotalSizeMB: 800,
|
||||
}
|
||||
|
||||
respExceededPeriod := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
backupConfigExceededPeriod,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
assert.Contains(t, string(respExceededPeriod.Body), "storage period exceeds plan limit")
|
||||
|
||||
// Test 4: Save backup config within all limits - should succeed
|
||||
backupConfigValid := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodWeek, // Within Month limit
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 80, // Within 100 limit
|
||||
MaxBackupsTotalSizeMB: 800, // Within 1000 limit
|
||||
}
|
||||
|
||||
var responseValid BackupConfig
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
backupConfigValid,
|
||||
http.StatusOK,
|
||||
&responseValid,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, responseValid.DatabaseID)
|
||||
assert.Equal(t, int64(80), responseValid.MaxBackupSizeMB)
|
||||
assert.Equal(t, int64(800), responseValid.MaxBackupsTotalSizeMB)
|
||||
assert.Equal(t, period.PeriodWeek, responseValid.RetentionTimePeriod)
|
||||
}
|
||||
|
||||
func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/features/storages"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
@@ -20,7 +19,6 @@ var (
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
plans.GetDatabasePlanService(),
|
||||
nil,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/features/storages"
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
@@ -42,11 +41,6 @@ type BackupConfig struct {
|
||||
MaxFailedTriesCount int `json:"maxFailedTriesCount" gorm:"column:max_failed_tries_count;type:int;not null"`
|
||||
|
||||
Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
|
||||
|
||||
// MaxBackupSizeMB limits individual backup size. 0 = unlimited.
|
||||
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
|
||||
// MaxBackupsTotalSizeMB limits total size of all backups. 0 = unlimited.
|
||||
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
|
||||
}
|
||||
|
||||
func (h *BackupConfig) TableName() string {
|
||||
@@ -86,12 +80,12 @@ func (b *BackupConfig) AfterFind(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
|
||||
func (b *BackupConfig) Validate() error {
|
||||
if b.BackupIntervalID == uuid.Nil && b.BackupInterval == nil {
|
||||
return errors.New("backup interval is required")
|
||||
}
|
||||
|
||||
if err := b.validateRetentionPolicy(plan); err != nil {
|
||||
if err := b.validateRetentionPolicy(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -110,67 +104,38 @@ func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
|
||||
}
|
||||
}
|
||||
|
||||
if b.MaxBackupSizeMB < 0 {
|
||||
return errors.New("max backup size must be non-negative")
|
||||
}
|
||||
|
||||
if b.MaxBackupsTotalSizeMB < 0 {
|
||||
return errors.New("max backups total size must be non-negative")
|
||||
}
|
||||
|
||||
if plan.MaxBackupSizeMB > 0 {
|
||||
if b.MaxBackupSizeMB == 0 || b.MaxBackupSizeMB > plan.MaxBackupSizeMB {
|
||||
return errors.New("max backup size exceeds plan limit")
|
||||
}
|
||||
}
|
||||
|
||||
if plan.MaxBackupsTotalSizeMB > 0 {
|
||||
if b.MaxBackupsTotalSizeMB == 0 ||
|
||||
b.MaxBackupsTotalSizeMB > plan.MaxBackupsTotalSizeMB {
|
||||
return errors.New("max total backups size exceeds plan limit")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
|
||||
return &BackupConfig{
|
||||
DatabaseID: newDatabaseID,
|
||||
IsBackupsEnabled: b.IsBackupsEnabled,
|
||||
RetentionPolicyType: b.RetentionPolicyType,
|
||||
RetentionTimePeriod: b.RetentionTimePeriod,
|
||||
RetentionCount: b.RetentionCount,
|
||||
RetentionGfsHours: b.RetentionGfsHours,
|
||||
RetentionGfsDays: b.RetentionGfsDays,
|
||||
RetentionGfsWeeks: b.RetentionGfsWeeks,
|
||||
RetentionGfsMonths: b.RetentionGfsMonths,
|
||||
RetentionGfsYears: b.RetentionGfsYears,
|
||||
BackupIntervalID: uuid.Nil,
|
||||
BackupInterval: b.BackupInterval.Copy(),
|
||||
StorageID: b.StorageID,
|
||||
SendNotificationsOn: b.SendNotificationsOn,
|
||||
IsRetryIfFailed: b.IsRetryIfFailed,
|
||||
MaxFailedTriesCount: b.MaxFailedTriesCount,
|
||||
Encryption: b.Encryption,
|
||||
MaxBackupSizeMB: b.MaxBackupSizeMB,
|
||||
MaxBackupsTotalSizeMB: b.MaxBackupsTotalSizeMB,
|
||||
DatabaseID: newDatabaseID,
|
||||
IsBackupsEnabled: b.IsBackupsEnabled,
|
||||
RetentionPolicyType: b.RetentionPolicyType,
|
||||
RetentionTimePeriod: b.RetentionTimePeriod,
|
||||
RetentionCount: b.RetentionCount,
|
||||
RetentionGfsHours: b.RetentionGfsHours,
|
||||
RetentionGfsDays: b.RetentionGfsDays,
|
||||
RetentionGfsWeeks: b.RetentionGfsWeeks,
|
||||
RetentionGfsMonths: b.RetentionGfsMonths,
|
||||
RetentionGfsYears: b.RetentionGfsYears,
|
||||
BackupIntervalID: uuid.Nil,
|
||||
BackupInterval: b.BackupInterval.Copy(),
|
||||
StorageID: b.StorageID,
|
||||
SendNotificationsOn: b.SendNotificationsOn,
|
||||
IsRetryIfFailed: b.IsRetryIfFailed,
|
||||
MaxFailedTriesCount: b.MaxFailedTriesCount,
|
||||
Encryption: b.Encryption,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BackupConfig) validateRetentionPolicy(plan *plans.DatabasePlan) error {
|
||||
func (b *BackupConfig) validateRetentionPolicy() error {
|
||||
switch b.RetentionPolicyType {
|
||||
case RetentionPolicyTypeTimePeriod, "":
|
||||
if b.RetentionTimePeriod == "" {
|
||||
return errors.New("retention time period is required")
|
||||
}
|
||||
|
||||
if plan.MaxStoragePeriod != period.PeriodForever {
|
||||
if b.RetentionTimePeriod.CompareTo(plan.MaxStoragePeriod) > 0 {
|
||||
return errors.New("storage period exceeds plan limit")
|
||||
}
|
||||
}
|
||||
|
||||
case RetentionPolicyTypeCount:
|
||||
if b.RetentionCount <= 0 {
|
||||
return errors.New("retention count must be greater than 0")
|
||||
|
||||
@@ -6,248 +6,34 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
|
||||
func Test_Validate_WhenRetentionTimePeriodIsWeekAndPlanAllowsMonth_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodWeek
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodMonth
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRetentionTimePeriodIsYearAndPlanAllowsMonth_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodYear
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodMonth
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "storage period exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRetentionTimePeriodIsForeverAndPlanAllowsForever_ValidationPasses(
|
||||
t *testing.T,
|
||||
) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodForever
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodForever
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRetentionTimePeriodIsForeverAndPlanAllowsYear_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodForever
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodYear
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "storage period exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRetentionTimePeriodEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodMonth
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodMonth
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSize100MBAndPlanAllows500MB_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 100
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 500
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSize500MBAndPlanAllows100MB_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 500
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 100
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max backup size exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 0
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanHas500MBLimit_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 500
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max backup size exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 500
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 500
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSize1GBAndPlanAllows5GB_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 1000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSize5GBAndPlanAllows1GB_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 1000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max total backups size exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 0
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanHas1GBLimit_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 1000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max total backups size exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenAllLimitsAreUnlimitedInPlan_AnyConfigurationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodForever
|
||||
config.MaxBackupSizeMB = 0
|
||||
config.MaxBackupsTotalSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenMultipleLimitsExceeded_ValidationFailsWithFirstError(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodYear
|
||||
config.MaxBackupSizeMB = 500
|
||||
config.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodMonth
|
||||
plan.MaxBackupSizeMB = 100
|
||||
plan.MaxBackupsTotalSizeMB = 1000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.Error(t, err)
|
||||
assert.EqualError(t, err, "storage period exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenConfigHasInvalidIntervalButPlanIsValid_ValidationFailsOnInterval(
|
||||
t *testing.T,
|
||||
) {
|
||||
func Test_Validate_WhenIntervalIsMissing_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.BackupIntervalID = uuid.Nil
|
||||
config.BackupInterval = nil
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "backup interval is required")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenIntervalIsMissing_ValidationFailsRegardlessOfPlan(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.BackupIntervalID = uuid.Nil
|
||||
config.BackupInterval = nil
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "backup interval is required")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRetryEnabledButMaxTriesIsZero_ValidationFailsRegardlessOfPlan(t *testing.T) {
|
||||
func Test_Validate_WhenRetryEnabledButMaxTriesIsZero_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.IsRetryIfFailed = true
|
||||
config.MaxFailedTriesCount = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "max failed tries count must be greater than 0")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenEncryptionIsInvalid_ValidationFailsRegardlessOfPlan(t *testing.T) {
|
||||
func Test_Validate_WhenEncryptionIsInvalid_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.Encryption = "INVALID"
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "encryption must be NONE or ENCRYPTED")
|
||||
}
|
||||
|
||||
@@ -255,125 +41,16 @@ func Test_Validate_WhenRetentionTimePeriodIsEmpty_ValidationFails(t *testing.T)
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = ""
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "retention time period is required")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenMaxBackupSizeIsNegative_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = -100
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max backup size must be non-negative")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenMaxTotalSizeIsNegative_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = -1000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max backups total size must be non-negative")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenPlanLimitsAreAtBoundary_ValidationWorks(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
configPeriod period.TimePeriod
|
||||
planPeriod period.TimePeriod
|
||||
configSize int64
|
||||
planSize int64
|
||||
configTotal int64
|
||||
planTotal int64
|
||||
shouldSucceed bool
|
||||
}{
|
||||
{
|
||||
name: "all values just under limit",
|
||||
configPeriod: period.PeriodWeek,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 99,
|
||||
planSize: 100,
|
||||
configTotal: 999,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: true,
|
||||
},
|
||||
{
|
||||
name: "all values equal to limit",
|
||||
configPeriod: period.PeriodMonth,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 100,
|
||||
planSize: 100,
|
||||
configTotal: 1000,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: true,
|
||||
},
|
||||
{
|
||||
name: "period just over limit",
|
||||
configPeriod: period.Period3Month,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 100,
|
||||
planSize: 100,
|
||||
configTotal: 1000,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: false,
|
||||
},
|
||||
{
|
||||
name: "size just over limit",
|
||||
configPeriod: period.PeriodMonth,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 101,
|
||||
planSize: 100,
|
||||
configTotal: 1000,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: false,
|
||||
},
|
||||
{
|
||||
name: "total size just over limit",
|
||||
configPeriod: period.PeriodMonth,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 100,
|
||||
planSize: 100,
|
||||
configTotal: 1001,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = tt.configPeriod
|
||||
config.MaxBackupSizeMB = tt.configSize
|
||||
config.MaxBackupsTotalSizeMB = tt.configTotal
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = tt.planPeriod
|
||||
plan.MaxBackupSizeMB = tt.planSize
|
||||
plan.MaxBackupsTotalSizeMB = tt.planTotal
|
||||
|
||||
err := config.Validate(plan)
|
||||
if tt.shouldSucceed {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Validate_WhenPolicyTypeIsCount_RequiresPositiveCount(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionPolicyType = RetentionPolicyTypeCount
|
||||
config.RetentionCount = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "retention count must be greater than 0")
|
||||
}
|
||||
|
||||
@@ -382,9 +59,7 @@ func Test_Validate_WhenPolicyTypeIsCount_WithPositiveCount_ValidationPasses(t *t
|
||||
config.RetentionPolicyType = RetentionPolicyTypeCount
|
||||
config.RetentionCount = 10
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -396,9 +71,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_RequiresAtLeastOneField(t *testing.T) {
|
||||
config.RetentionGfsMonths = 0
|
||||
config.RetentionGfsYears = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "at least one GFS retention field must be greater than 0")
|
||||
}
|
||||
|
||||
@@ -407,9 +80,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_WithOnlyHours_ValidationPasses(t *testing
|
||||
config.RetentionPolicyType = RetentionPolicyTypeGFS
|
||||
config.RetentionGfsHours = 24
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -418,9 +89,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_WithOnlyDays_ValidationPasses(t *testing.
|
||||
config.RetentionPolicyType = RetentionPolicyTypeGFS
|
||||
config.RetentionGfsDays = 7
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -433,9 +102,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_WithAllFields_ValidationPasses(t *testing
|
||||
config.RetentionGfsMonths = 12
|
||||
config.RetentionGfsYears = 3
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -443,35 +110,59 @@ func Test_Validate_WhenPolicyTypeIsInvalid_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionPolicyType = "INVALID"
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "invalid retention policy type")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenCloudAndEncryptionIsNotEncrypted_ValidationFails(t *testing.T) {
|
||||
enableCloud(t)
|
||||
|
||||
backupConfig := createValidBackupConfig()
|
||||
backupConfig.Encryption = BackupEncryptionNone
|
||||
|
||||
err := backupConfig.Validate()
|
||||
assert.EqualError(t, err, "encryption is mandatory for cloud storage")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenCloudAndEncryptionIsEncrypted_ValidationPasses(t *testing.T) {
|
||||
enableCloud(t)
|
||||
|
||||
backupConfig := createValidBackupConfig()
|
||||
backupConfig.Encryption = BackupEncryptionEncrypted
|
||||
|
||||
err := backupConfig.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenNotCloudAndEncryptionIsNotEncrypted_ValidationPasses(t *testing.T) {
|
||||
backupConfig := createValidBackupConfig()
|
||||
backupConfig.Encryption = BackupEncryptionNone
|
||||
|
||||
err := backupConfig.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func enableCloud(t *testing.T) {
|
||||
t.Helper()
|
||||
config.GetEnv().IsCloud = true
|
||||
t.Cleanup(func() {
|
||||
config.GetEnv().IsCloud = false
|
||||
})
|
||||
}
|
||||
|
||||
func createValidBackupConfig() *BackupConfig {
|
||||
intervalID := uuid.New()
|
||||
return &BackupConfig{
|
||||
DatabaseID: uuid.New(),
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodMonth,
|
||||
BackupIntervalID: intervalID,
|
||||
BackupInterval: &intervals.Interval{ID: intervalID},
|
||||
SendNotificationsOn: []BackupNotificationType{},
|
||||
IsRetryIfFailed: false,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 100,
|
||||
MaxBackupsTotalSizeMB: 1000,
|
||||
}
|
||||
}
|
||||
|
||||
func createUnlimitedPlan() *plans.DatabasePlan {
|
||||
return &plans.DatabasePlan{
|
||||
DatabaseID: uuid.New(),
|
||||
MaxBackupSizeMB: 0,
|
||||
MaxBackupsTotalSizeMB: 0,
|
||||
MaxStoragePeriod: period.PeriodForever,
|
||||
return &BackupConfig{
|
||||
DatabaseID: uuid.New(),
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodMonth,
|
||||
BackupIntervalID: intervalID,
|
||||
BackupInterval: &intervals.Interval{ID: intervalID},
|
||||
SendNotificationsOn: []BackupNotificationType{},
|
||||
IsRetryIfFailed: false,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,10 +8,10 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_models "databasus-backend/internal/features/users/models"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
|
||||
type BackupConfigService struct {
|
||||
@@ -20,7 +20,6 @@ type BackupConfigService struct {
|
||||
storageService *storages.StorageService
|
||||
notifierService *notifiers.NotifierService
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
databasePlanService *plans.DatabasePlanService
|
||||
|
||||
dbStorageChangeListener BackupConfigStorageChangeListener
|
||||
}
|
||||
@@ -46,12 +45,7 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
|
||||
user *users_models.User,
|
||||
backupConfig *BackupConfig,
|
||||
) (*BackupConfig, error) {
|
||||
plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := backupConfig.Validate(plan); err != nil {
|
||||
if err := backupConfig.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -88,12 +82,7 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
|
||||
func (s *BackupConfigService) SaveBackupConfig(
|
||||
backupConfig *BackupConfig,
|
||||
) (*BackupConfig, error) {
|
||||
plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := backupConfig.Validate(plan); err != nil {
|
||||
if err := backupConfig.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -131,18 +120,6 @@ func (s *BackupConfigService) GetBackupConfigByDbIdWithAuth(
|
||||
return s.GetBackupConfigByDbId(databaseID)
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) GetDatabasePlan(
|
||||
user *users_models.User,
|
||||
databaseID uuid.UUID,
|
||||
) (*plans.DatabasePlan, error) {
|
||||
_, err := s.databaseService.GetDatabase(user, databaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.databasePlanService.GetDatabasePlan(databaseID)
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) GetBackupConfigByDbId(
|
||||
databaseID uuid.UUID,
|
||||
) (*BackupConfig, error) {
|
||||
@@ -322,20 +299,13 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace(
|
||||
func (s *BackupConfigService) initializeDefaultConfig(
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
plan, err := s.databasePlanService.GetDatabasePlan(databaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
timeOfDay := "04:00"
|
||||
|
||||
_, err = s.backupConfigRepository.Save(&BackupConfig{
|
||||
DatabaseID: databaseID,
|
||||
IsBackupsEnabled: false,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: plan.MaxStoragePeriod,
|
||||
MaxBackupSizeMB: plan.MaxBackupSizeMB,
|
||||
MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB,
|
||||
_, err := s.backupConfigRepository.Save(&BackupConfig{
|
||||
DatabaseID: databaseID,
|
||||
IsBackupsEnabled: false,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.Period3Month,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
|
||||
305
backend/internal/features/billing/controller.go
Normal file
305
backend/internal/features/billing/controller.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
type BillingController struct {
|
||||
billingService *BillingService
|
||||
}
|
||||
|
||||
func (c *BillingController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
billing := router.Group("/billing")
|
||||
|
||||
billing.POST("/subscription", c.CreateSubscription)
|
||||
billing.POST("/subscription/change-storage", c.ChangeSubscriptionStorage)
|
||||
billing.POST("/subscription/portal/:subscription_id", c.GetPortalSession)
|
||||
billing.GET("/subscription/events/:subscription_id", c.GetSubscriptionEvents)
|
||||
billing.GET("/subscription/invoices/:subscription_id", c.GetInvoices)
|
||||
billing.GET("/subscription/:database_id", c.GetSubscription)
|
||||
}
|
||||
|
||||
// CreateSubscription
|
||||
// @Summary Create a new subscription
|
||||
// @Description Create a billing subscription for the specified database with the given storage
|
||||
// @Tags billing
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body CreateSubscriptionRequest true "Subscription creation data"
|
||||
// @Success 200 {object} CreateSubscriptionResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /billing/subscription [post]
|
||||
func (c *BillingController) CreateSubscription(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request CreateSubscriptionRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(400, gin.H{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
log := logger.GetLogger().With(
|
||||
"request_id", uuid.New(),
|
||||
"database_id", request.DatabaseID,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
transactionID, err := c.billingService.CreateSubscription(
|
||||
log,
|
||||
user,
|
||||
request.DatabaseID,
|
||||
request.StorageGB,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error("Failed to create subscription", "error", err)
|
||||
ctx.JSON(500, gin.H{"error": "Failed to create subscription"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(200, CreateSubscriptionResponse{PaddleTransactionID: transactionID})
|
||||
}
|
||||
|
||||
// ChangeSubscriptionStorage
|
||||
// @Summary Change subscription storage
|
||||
// @Description Update the storage allocation for an existing subscription
|
||||
// @Tags billing
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body ChangeStorageRequest true "New storage configuration"
|
||||
// @Success 200 {object} ChangeStorageResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /billing/subscription/change-storage [post]
|
||||
func (c *BillingController) ChangeSubscriptionStorage(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request ChangeStorageRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(400, gin.H{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
log := logger.GetLogger().With(
|
||||
"request_id", uuid.New(),
|
||||
"database_id", request.DatabaseID,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
result, err := c.billingService.ChangeSubscriptionStorage(log, user, request.DatabaseID, request.StorageGB)
|
||||
if err != nil {
|
||||
log.Error("Failed to change subscription storage", "error", err)
|
||||
ctx.JSON(500, gin.H{"error": "Failed to change subscription storage"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(200, ChangeStorageResponse{
|
||||
ApplyMode: result.ApplyMode,
|
||||
CurrentGB: result.CurrentGB,
|
||||
PendingGB: result.PendingGB,
|
||||
})
|
||||
}
|
||||
|
||||
// GetPortalSession
|
||||
// @Summary Get billing portal session
|
||||
// @Description Generate a portal session URL for managing the subscription
|
||||
// @Tags billing
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param subscription_id path string true "Subscription ID"
|
||||
// @Success 200 {object} GetPortalSessionResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /billing/subscription/portal/{subscription_id} [post]
|
||||
func (c *BillingController) GetPortalSession(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
subscriptionID := ctx.Param("subscription_id")
|
||||
if subscriptionID == "" {
|
||||
ctx.JSON(400, gin.H{"error": "Subscription ID is required"})
|
||||
return
|
||||
}
|
||||
|
||||
log := logger.GetLogger().With(
|
||||
"request_id", uuid.New(),
|
||||
"subscription_id", subscriptionID,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
url, err := c.billingService.GetPortalURL(log, user, uuid.MustParse(subscriptionID))
|
||||
if err != nil {
|
||||
log.Error("Failed to get portal session", "error", err)
|
||||
ctx.JSON(500, gin.H{"error": "Failed to get portal session"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(200, GetPortalSessionResponse{PortalURL: url})
|
||||
}
|
||||
|
||||
// GetSubscriptionEvents
|
||||
// @Summary Get subscription events
|
||||
// @Description Retrieve the event history for a subscription
|
||||
// @Tags billing
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param subscription_id path string true "Subscription ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Success 200 {object} GetSubscriptionEventsResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /billing/subscription/events/{subscription_id} [get]
|
||||
func (c *BillingController) GetSubscriptionEvents(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
subscriptionID, err := uuid.Parse(ctx.Param("subscription_id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid subscription ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var request PaginatedRequest
|
||||
if err := ctx.ShouldBindQuery(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
log := logger.GetLogger().With(
|
||||
"request_id", uuid.New(),
|
||||
"subscription_id", subscriptionID,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
response, err := c.billingService.GetSubscriptionEvents(log, user, subscriptionID, request.Limit, request.Offset)
|
||||
if err != nil {
|
||||
log.Error("Failed to get subscription events", "error", err)
|
||||
ctx.JSON(500, gin.H{"error": "Failed to get subscription events"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(200, response)
|
||||
}
|
||||
|
||||
// GetInvoices
|
||||
// @Summary Get subscription invoices
|
||||
// @Description Retrieve all invoices for a subscription
|
||||
// @Tags billing
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param subscription_id path string true "Subscription ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Success 200 {object} GetInvoicesResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /billing/subscription/invoices/{subscription_id} [get]
|
||||
func (c *BillingController) GetInvoices(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
subscriptionID, err := uuid.Parse(ctx.Param("subscription_id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid subscription ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var request PaginatedRequest
|
||||
if err := ctx.ShouldBindQuery(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
log := logger.GetLogger().With(
|
||||
"request_id", uuid.New(),
|
||||
"subscription_id", subscriptionID,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
response, err := c.billingService.GetSubscriptionInvoices(log, user, subscriptionID, request.Limit, request.Offset)
|
||||
if err != nil {
|
||||
log.Error("Failed to get invoices", "error", err)
|
||||
ctx.JSON(500, gin.H{"error": "Failed to get invoices"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(200, response)
|
||||
}
|
||||
|
||||
// GetSubscription
|
||||
// @Summary Get subscription by database
|
||||
// @Description Retrieve the subscription associated with a specific database
|
||||
// @Tags billing
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param database_id path string true "Database ID"
|
||||
// @Success 200 {object} billing_models.Subscription
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /billing/subscription/{database_id} [get]
|
||||
func (c *BillingController) GetSubscription(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
databaseID, err := uuid.Parse(ctx.Param("database_id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid database ID"})
|
||||
return
|
||||
}
|
||||
|
||||
log := logger.GetLogger().With(
|
||||
"request_id", uuid.New(),
|
||||
"database_id", databaseID,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
subscription, err := c.billingService.GetSubscriptionByDatabaseID(log, user, databaseID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSubscriptionNotFound) {
|
||||
ctx.JSON(http.StatusNotFound, gin.H{"error": "Subscription not found"})
|
||||
return
|
||||
}
|
||||
|
||||
log.Error("failed to get subscription", "error", err)
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get subscription"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(200, subscription)
|
||||
}
|
||||
1450
backend/internal/features/billing/controller_test.go
Normal file
1450
backend/internal/features/billing/controller_test.go
Normal file
File diff suppressed because it is too large
Load Diff
49
backend/internal/features/billing/di.go
Normal file
49
backend/internal/features/billing/di.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
billing_repositories "databasus-backend/internal/features/billing/repositories"
|
||||
"databasus-backend/internal/features/databases"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
billingService = &BillingService{
|
||||
&billing_repositories.SubscriptionRepository{},
|
||||
&billing_repositories.SubscriptionEventRepository{},
|
||||
&billing_repositories.InvoiceRepository{},
|
||||
nil, // billing provider will be set later to avoid circular dependency
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
*databases.GetDatabaseService(),
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
billingController = &BillingController{billingService}
|
||||
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func GetBillingService() *BillingService {
|
||||
return billingService
|
||||
}
|
||||
|
||||
func GetBillingController() *BillingController {
|
||||
return billingController
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
databases.GetDatabaseService().AddDbCreationListener(billingService)
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("billing.SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
67
backend/internal/features/billing/dto.go
Normal file
67
backend/internal/features/billing/dto.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
)
|
||||
|
||||
type CreateSubscriptionRequest struct {
|
||||
DatabaseID uuid.UUID `json:"databaseId" validate:"required"`
|
||||
StorageGB int `json:"storageGb" validate:"required,min=1"`
|
||||
}
|
||||
|
||||
type CreateSubscriptionResponse struct {
|
||||
PaddleTransactionID string `json:"paddleTransactionId"`
|
||||
}
|
||||
|
||||
type ChangeStorageApplyMode string
|
||||
|
||||
const (
|
||||
ChangeStorageApplyImmediate ChangeStorageApplyMode = "immediate"
|
||||
ChangeStorageApplyNextCycle ChangeStorageApplyMode = "next_cycle"
|
||||
)
|
||||
|
||||
type ChangeStorageRequest struct {
|
||||
DatabaseID uuid.UUID `json:"databaseId" validate:"required"`
|
||||
StorageGB int `json:"storageGb" validate:"required,min=1"`
|
||||
}
|
||||
|
||||
type ChangeStorageResponse struct {
|
||||
ApplyMode ChangeStorageApplyMode `json:"applyMode"`
|
||||
CurrentGB int `json:"currentGb"`
|
||||
PendingGB *int `json:"pendingGb,omitempty"`
|
||||
}
|
||||
|
||||
type PortalResponse struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type ChangeStorageResult struct {
|
||||
ApplyMode ChangeStorageApplyMode
|
||||
CurrentGB int
|
||||
PendingGB *int
|
||||
}
|
||||
|
||||
type GetPortalSessionResponse struct {
|
||||
PortalURL string `json:"url"`
|
||||
}
|
||||
|
||||
type PaginatedRequest struct {
|
||||
Limit int `form:"limit" json:"limit"`
|
||||
Offset int `form:"offset" json:"offset"`
|
||||
}
|
||||
|
||||
type GetSubscriptionEventsResponse struct {
|
||||
Events []*billing_models.SubscriptionEvent `json:"events"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
type GetInvoicesResponse struct {
|
||||
Invoices []*billing_models.Invoice `json:"invoices"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
15
backend/internal/features/billing/errors.go
Normal file
15
backend/internal/features/billing/errors.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package billing
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrInvalidStorage = errors.New("storage must be between 20 and 10000 GB")
|
||||
ErrAlreadySubscribed = errors.New("database already has an active subscription")
|
||||
ErrExceedsUsage = errors.New("cannot downgrade below current storage usage")
|
||||
ErrNoChange = errors.New("requested storage is the same as current")
|
||||
ErrDuplicate = errors.New("duplicate event already processed")
|
||||
ErrProviderUnavailable = errors.New("payment provider unavailable")
|
||||
ErrNoActiveSubscription = errors.New("no active subscription for this database")
|
||||
ErrAccessDenied = errors.New("user does not have access to this database")
|
||||
ErrSubscriptionNotFound = errors.New("subscription not found")
|
||||
)
|
||||
24
backend/internal/features/billing/models/invoice.go
Normal file
24
backend/internal/features/billing/models/invoice.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package billing_models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Invoice struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
SubscriptionID uuid.UUID `json:"subscriptionId" gorm:"column:subscription_id;type:uuid;not null"`
|
||||
ProviderInvoiceID string `json:"providerInvoiceId" gorm:"column:provider_invoice_id;type:text;not null"`
|
||||
AmountCents int64 `json:"amountCents" gorm:"column:amount_cents;type:bigint;not null"`
|
||||
StorageGB int `json:"storageGb" gorm:"column:storage_gb;type:int;not null"`
|
||||
PeriodStart time.Time `json:"periodStart" gorm:"column:period_start;type:timestamptz;not null"`
|
||||
PeriodEnd time.Time `json:"periodEnd" gorm:"column:period_end;type:timestamptz;not null"`
|
||||
Status InvoiceStatus `json:"status" gorm:"column:status;type:text;not null"`
|
||||
PaidAt *time.Time `json:"paidAt,omitempty" gorm:"column:paid_at;type:timestamptz"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"`
|
||||
}
|
||||
|
||||
func (Invoice) TableName() string {
|
||||
return "invoices"
|
||||
}
|
||||
11
backend/internal/features/billing/models/invoice_status.go
Normal file
11
backend/internal/features/billing/models/invoice_status.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package billing_models
|
||||
|
||||
type InvoiceStatus string
|
||||
|
||||
const (
|
||||
InvoiceStatusPending InvoiceStatus = "pending"
|
||||
InvoiceStatusPaid InvoiceStatus = "paid"
|
||||
InvoiceStatusFailed InvoiceStatus = "failed"
|
||||
InvoiceStatusRefunded InvoiceStatus = "refunded"
|
||||
InvoiceStatusDisputed InvoiceStatus = "disputed"
|
||||
)
|
||||
72
backend/internal/features/billing/models/subscription.go
Normal file
72
backend/internal/features/billing/models/subscription.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package billing_models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
)
|
||||
|
||||
type Subscription struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
|
||||
Status SubscriptionStatus `json:"status" gorm:"column:status;type:text;not null"`
|
||||
|
||||
StorageGB int `json:"storageGb" gorm:"column:storage_gb;type:int;not null"`
|
||||
PendingStorageGB *int `json:"pendingStorageGb,omitempty" gorm:"column:pending_storage_gb;type:int"`
|
||||
|
||||
CurrentPeriodStart time.Time `json:"currentPeriodStart" gorm:"column:current_period_start;type:timestamptz;not null"`
|
||||
CurrentPeriodEnd time.Time `json:"currentPeriodEnd" gorm:"column:current_period_end;type:timestamptz;not null"`
|
||||
CanceledAt *time.Time `json:"canceledAt,omitempty" gorm:"column:canceled_at;type:timestamptz"`
|
||||
|
||||
DataRetentionGracePeriodUntil *time.Time `json:"dataRetentionGracePeriodUntil,omitempty" gorm:"column:data_retention_grace_period_until;type:timestamptz"`
|
||||
|
||||
ProviderName *string `json:"providerName,omitempty" gorm:"column:provider_name;type:text"`
|
||||
ProviderSubID *string `json:"providerSubId,omitempty" gorm:"column:provider_sub_id;type:text"`
|
||||
ProviderCustomerID *string `json:"providerCustomerId,omitempty" gorm:"column:provider_customer_id;type:text"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"`
|
||||
UpdatedAt time.Time `json:"updatedAt" gorm:"column:updated_at;type:timestamptz;not null"`
|
||||
}
|
||||
|
||||
func (Subscription) TableName() string {
|
||||
return "subscriptions"
|
||||
}
|
||||
|
||||
func (s *Subscription) PriceCents() int64 {
|
||||
return int64(s.StorageGB) * config.GetEnv().PricePerGBCents
|
||||
}
|
||||
|
||||
// CanCreateNewBackups - whether it is allowed to create new backups
|
||||
// by scheduler or for user manually. Clarification: in grace period
|
||||
// user can download, delete and restore backups, but cannot create new ones
|
||||
func (s *Subscription) CanCreateNewBackups() bool {
|
||||
switch s.Status {
|
||||
case StatusActive, StatusPastDue:
|
||||
return true
|
||||
case StatusTrial, StatusCanceled:
|
||||
return time.Now().Before(s.CurrentPeriodEnd)
|
||||
case StatusExpired:
|
||||
return false
|
||||
default:
|
||||
panic("unknown subscription status")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Subscription) GetBackupsStorageGB() int {
|
||||
switch s.Status {
|
||||
case StatusActive, StatusPastDue, StatusCanceled:
|
||||
return s.StorageGB
|
||||
case StatusTrial:
|
||||
if time.Now().Before(s.CurrentPeriodEnd) {
|
||||
return s.StorageGB
|
||||
}
|
||||
|
||||
return 0
|
||||
case StatusExpired:
|
||||
return 0
|
||||
default:
|
||||
panic("unknown subscription status")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package billing_models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type SubscriptionEvent struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
SubscriptionID uuid.UUID `json:"subscriptionId" gorm:"column:subscription_id;type:uuid;not null"`
|
||||
ProviderEventID *string `json:"providerEventId,omitempty" gorm:"column:provider_event_id;type:text"`
|
||||
Type SubscriptionEventType `json:"type" gorm:"column:type;type:text;not null"`
|
||||
|
||||
OldStorageGB *int `json:"oldStorageGb,omitempty" gorm:"column:old_storage_gb;type:int"`
|
||||
NewStorageGB *int `json:"newStorageGb,omitempty" gorm:"column:new_storage_gb;type:int"`
|
||||
OldStatus *SubscriptionStatus `json:"oldStatus,omitempty" gorm:"column:old_status;type:text"`
|
||||
NewStatus *SubscriptionStatus `json:"newStatus,omitempty" gorm:"column:new_status;type:text"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"`
|
||||
}
|
||||
|
||||
func (SubscriptionEvent) TableName() string {
|
||||
return "subscription_events"
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package billing_models
|
||||
|
||||
type SubscriptionEventType string
|
||||
|
||||
const (
|
||||
EventCreated SubscriptionEventType = "subscription.created"
|
||||
EventUpgraded SubscriptionEventType = "subscription.upgraded"
|
||||
EventDowngraded SubscriptionEventType = "subscription.downgraded"
|
||||
EventNewBillingCycleStarted SubscriptionEventType = "subscription.new_billing_cycle_started"
|
||||
EventCanceled SubscriptionEventType = "subscription.canceled"
|
||||
EventReactivated SubscriptionEventType = "subscription.reactivated"
|
||||
EventExpired SubscriptionEventType = "subscription.expired"
|
||||
EventPastDue SubscriptionEventType = "subscription.past_due"
|
||||
EventRecoveredFromPastDue SubscriptionEventType = "subscription.recovered_from_past_due"
|
||||
EventRefund SubscriptionEventType = "payment.refund"
|
||||
EventDispute SubscriptionEventType = "payment.dispute"
|
||||
)
|
||||
@@ -0,0 +1,11 @@
|
||||
package billing_models
|
||||
|
||||
type SubscriptionStatus string
|
||||
|
||||
const (
|
||||
StatusTrial SubscriptionStatus = "trial" // trial period (~24h after DB creation)
|
||||
StatusActive SubscriptionStatus = "active" // paid, everything works
|
||||
StatusPastDue SubscriptionStatus = "past_due" // payment failed, trying to charge again, but everything still works
|
||||
StatusCanceled SubscriptionStatus = "canceled" // subscription canceled by user or after past_due (grace period is active)
|
||||
StatusExpired SubscriptionStatus = "expired" // grace period ended, data marked for deletion, can come from canceled and trial
|
||||
)
|
||||
22
backend/internal/features/billing/models/webhook_event.go
Normal file
22
backend/internal/features/billing/models/webhook_event.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package billing_models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type WebhookEvent struct {
|
||||
RequestID uuid.UUID
|
||||
ProviderEventID string
|
||||
DatabaseID *uuid.UUID
|
||||
Type WebhookEventType
|
||||
ProviderSubscriptionID string
|
||||
ProviderCustomerID string
|
||||
ProviderInvoiceID string
|
||||
QuantityGB int
|
||||
Status SubscriptionStatus
|
||||
PeriodStart *time.Time
|
||||
PeriodEnd *time.Time
|
||||
AmountCents int64
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package billing_models
|
||||
|
||||
type WebhookEventType string
|
||||
|
||||
const (
|
||||
WHEventSubscriptionCreated WebhookEventType = "subscription.created"
|
||||
WHEventSubscriptionUpdated WebhookEventType = "subscription.updated"
|
||||
WHEventSubscriptionCanceled WebhookEventType = "subscription.canceled"
|
||||
WHEventSubscriptionPastDue WebhookEventType = "subscription.past_due"
|
||||
WHEventSubscriptionReactivated WebhookEventType = "subscription.reactivated"
|
||||
WHEventPaymentSucceeded WebhookEventType = "payment.succeeded"
|
||||
WHEventSubscriptionDisputeCreated WebhookEventType = "dispute.created"
|
||||
)
|
||||
5
backend/internal/features/billing/paddle/README.md
Normal file
5
backend/internal/features/billing/paddle/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
**Paddle hints:**
|
||||
|
||||
- **max_quantity on price:** Paddle limits `quantity` on a price to 100 by default. You need to explicitly set the range (`quantity: {minimum: 20, maximum: 10000}`) when creating a price via API or dashboard. Otherwise requests with quantity > 100 will return an error.
|
||||
- **Full items list on update:** Unlike Stripe, Paddle requires sending **all** subscription items in `PATCH /subscriptions/{id}`, not just the changed ones. `proration_billing_mode` is also required. Without this you can accidentally remove a line item or get a 400.
|
||||
- **Webhook events mapping:** Paddle uses `transaction.completed` instead of `payment.succeeded`, `transaction.payment_failed` instead of `payment.failed`, `adjustment.created` instead of `dispute.created`.
|
||||
83
backend/internal/features/billing/paddle/controller.go
Normal file
83
backend/internal/features/billing/paddle/controller.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package billing_paddle
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
billing_webhooks "databasus-backend/internal/features/billing/webhooks"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
type PaddleBillingController struct {
|
||||
paddleBillingService *PaddleBillingService
|
||||
}
|
||||
|
||||
func (c *PaddleBillingController) RegisterPublicRoutes(router *gin.RouterGroup) {
|
||||
router.POST("/billing/paddle/webhook", c.HandlePaddleWebhook)
|
||||
}
|
||||
|
||||
// HandlePaddleWebhook
|
||||
// @Summary Handle Paddle webhook
|
||||
// @Description Process incoming webhook events from Paddle payment provider
|
||||
// @Tags billing
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Paddle-Signature header string true "Paddle webhook signature"
|
||||
// @Success 200
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500
|
||||
// @Router /billing/paddle/webhook [post]
|
||||
func (c *PaddleBillingController) HandlePaddleWebhook(ctx *gin.Context) {
|
||||
requestID := uuid.New()
|
||||
log := logger.GetLogger().With("request_id", requestID)
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(ctx.Request.Body, 1<<20))
|
||||
if err != nil {
|
||||
log.Error("failed to read webhook request body", "error", err)
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Failed to read request body"})
|
||||
return
|
||||
}
|
||||
|
||||
headers := make(map[string]string)
|
||||
for k := range ctx.Request.Header {
|
||||
headers[k] = ctx.GetHeader(k)
|
||||
}
|
||||
|
||||
if err := c.paddleBillingService.VerifyWebhookSignature(body, headers); err != nil {
|
||||
log.Warn("paddle webhook signature verification failed", "error", err)
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid webhook signature"})
|
||||
return
|
||||
}
|
||||
|
||||
var webhookDTO PaddleWebhookDTO
|
||||
if err := json.Unmarshal(body, &webhookDTO); err != nil {
|
||||
log.Error("failed to unmarshal webhook payload", "error", err)
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid webhook payload"})
|
||||
return
|
||||
}
|
||||
|
||||
log = log.With(
|
||||
"provider_event_id", webhookDTO.EventID,
|
||||
"event_type", webhookDTO.EventType,
|
||||
)
|
||||
|
||||
if err := c.paddleBillingService.ProcessWebhookEvent(log, requestID, webhookDTO, body); err != nil {
|
||||
if errors.Is(err, billing_webhooks.ErrDuplicateWebhook) {
|
||||
log.Info("duplicate webhook event, returning 200 to not force retry")
|
||||
ctx.Status(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
log.Error("Failed to process paddle webhook", "error", err)
|
||||
ctx.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusOK)
|
||||
}
|
||||
1056
backend/internal/features/billing/paddle/controller_test.go
Normal file
1056
backend/internal/features/billing/paddle/controller_test.go
Normal file
File diff suppressed because it is too large
Load Diff
72
backend/internal/features/billing/paddle/di.go
Normal file
72
backend/internal/features/billing/paddle/di.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package billing_paddle
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/PaddleHQ/paddle-go-sdk"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/billing"
|
||||
billing_webhooks "databasus-backend/internal/features/billing/webhooks"
|
||||
)
|
||||
|
||||
var (
|
||||
paddleBillingService *PaddleBillingService
|
||||
paddleBillingController *PaddleBillingController
|
||||
initOnce sync.Once
|
||||
)
|
||||
|
||||
func GetPaddleBillingService() *PaddleBillingService {
|
||||
if !config.GetEnv().IsCloud {
|
||||
return nil
|
||||
}
|
||||
|
||||
initOnce.Do(func() {
|
||||
if config.GetEnv().IsPaddleSandbox {
|
||||
paddleClient, err := paddle.NewSandbox(config.GetEnv().PaddleApiKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
paddleBillingService = &PaddleBillingService{
|
||||
paddleClient,
|
||||
paddle.NewWebhookVerifier(config.GetEnv().PaddleWebhookSecret),
|
||||
config.GetEnv().PaddlePriceID,
|
||||
billing_webhooks.WebhookRepository{},
|
||||
billing.GetBillingService(),
|
||||
}
|
||||
} else {
|
||||
paddleClient, err := paddle.New(config.GetEnv().PaddleApiKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
paddleBillingService = &PaddleBillingService{
|
||||
paddleClient,
|
||||
paddle.NewWebhookVerifier(config.GetEnv().PaddleWebhookSecret),
|
||||
config.GetEnv().PaddlePriceID,
|
||||
billing_webhooks.WebhookRepository{},
|
||||
billing.GetBillingService(),
|
||||
}
|
||||
}
|
||||
|
||||
paddleBillingController = &PaddleBillingController{paddleBillingService}
|
||||
})
|
||||
|
||||
return paddleBillingService
|
||||
}
|
||||
|
||||
func GetPaddleBillingController() *PaddleBillingController {
|
||||
if !config.GetEnv().IsCloud {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure service + controller are initialized
|
||||
GetPaddleBillingService()
|
||||
|
||||
return paddleBillingController
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
billing.GetBillingService().SetBillingProvider(GetPaddleBillingService())
|
||||
}
|
||||
9
backend/internal/features/billing/paddle/dto.go
Normal file
9
backend/internal/features/billing/paddle/dto.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package billing_paddle
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type PaddleWebhookDTO struct {
|
||||
EventID string `json:"event_id"`
|
||||
EventType string `json:"event_type"`
|
||||
Data json.RawMessage
|
||||
}
|
||||
50
backend/internal/features/billing/paddle/dto_test.go
Normal file
50
backend/internal/features/billing/paddle/dto_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package billing_paddle
|
||||
|
||||
import "time"
|
||||
|
||||
type TestSubscriptionCreatedPayload struct {
|
||||
EventID string
|
||||
SubID string
|
||||
CustomerID string
|
||||
DatabaseID string
|
||||
QuantityGB int
|
||||
PeriodStart time.Time
|
||||
PeriodEnd time.Time
|
||||
}
|
||||
|
||||
type TestSubscriptionUpdatedPayload struct {
|
||||
EventID string
|
||||
SubID string
|
||||
CustomerID string
|
||||
QuantityGB int
|
||||
PeriodStart time.Time
|
||||
PeriodEnd time.Time
|
||||
HasScheduledChange bool
|
||||
ScheduledChangeAction string
|
||||
}
|
||||
|
||||
type TestSubscriptionCanceledPayload struct {
|
||||
EventID string
|
||||
SubID string
|
||||
CustomerID string
|
||||
}
|
||||
|
||||
type TestTransactionCompletedPayload struct {
|
||||
EventID string
|
||||
TxnID string
|
||||
SubID string
|
||||
CustomerID string
|
||||
TotalCents int64
|
||||
QuantityGB int
|
||||
PeriodStart time.Time
|
||||
PeriodEnd time.Time
|
||||
}
|
||||
|
||||
type TestSubscriptionPastDuePayload struct {
|
||||
EventID string
|
||||
SubID string
|
||||
CustomerID string
|
||||
QuantityGB int
|
||||
PeriodStart time.Time
|
||||
PeriodEnd time.Time
|
||||
}
|
||||
638
backend/internal/features/billing/paddle/service.go
Normal file
638
backend/internal/features/billing/paddle/service.go
Normal file
@@ -0,0 +1,638 @@
|
||||
package billing_paddle
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/PaddleHQ/paddle-go-sdk"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/features/billing"
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
billing_provider "databasus-backend/internal/features/billing/provider"
|
||||
billing_webhooks "databasus-backend/internal/features/billing/webhooks"
|
||||
)
|
||||
|
||||
type PaddleBillingService struct {
|
||||
client *paddle.SDK
|
||||
webhookVerified *paddle.WebhookVerifier
|
||||
priceID string
|
||||
webhookRepository billing_webhooks.WebhookRepository
|
||||
billingService *billing.BillingService
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) GetProviderName() billing_provider.ProviderName {
|
||||
return billing_provider.ProviderPaddle
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) CreateCheckoutSession(
|
||||
logger *slog.Logger,
|
||||
request billing_provider.CheckoutRequest,
|
||||
) (string, error) {
|
||||
logger = logger.With("database_id", request.DatabaseID)
|
||||
logger.Debug(fmt.Sprintf("paddle: creating checkout session for %d GB", request.StorageGB))
|
||||
|
||||
txRequest := &paddle.CreateTransactionRequest{
|
||||
Items: []paddle.CreateTransactionItems{
|
||||
*paddle.NewCreateTransactionItemsCatalogItem(&paddle.CatalogItem{
|
||||
PriceID: s.priceID,
|
||||
Quantity: request.StorageGB,
|
||||
}),
|
||||
},
|
||||
CustomData: paddle.CustomData{"database_id": request.DatabaseID.String()},
|
||||
Checkout: &paddle.TransactionCheckout{
|
||||
URL: &request.SuccessURL,
|
||||
},
|
||||
}
|
||||
|
||||
tx, err := s.client.CreateTransaction(context.Background(), txRequest)
|
||||
if err != nil {
|
||||
logger.Error("paddle: failed to create transaction", "error", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
return tx.ID, nil
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) UpgradeQuantityWithSurcharge(
|
||||
logger *slog.Logger,
|
||||
providerSubscriptionID string,
|
||||
quantityGB int,
|
||||
) error {
|
||||
logger = logger.With("provider_subscription_id", providerSubscriptionID)
|
||||
logger.Debug(fmt.Sprintf("paddle: applying upgrade: new storage %d GB", quantityGB))
|
||||
|
||||
// important: paddle requires to send all items
|
||||
// in the subscription when updating, not just the changed one
|
||||
subscription, err := s.client.GetSubscription(context.Background(), &paddle.GetSubscriptionRequest{
|
||||
SubscriptionID: providerSubscriptionID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("paddle: failed to get subscription", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
currentQuantity := subscription.Items[0].Quantity
|
||||
if currentQuantity == quantityGB {
|
||||
logger.Info("paddle: subscription already at requested quantity, skipping upgrade",
|
||||
"current_quantity_gb", currentQuantity,
|
||||
"requested_quantity_gb", quantityGB,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
priceID := subscription.Items[0].Price.ID
|
||||
|
||||
_, err = s.client.UpdateSubscription(context.Background(), &paddle.UpdateSubscriptionRequest{
|
||||
SubscriptionID: providerSubscriptionID,
|
||||
Items: paddle.NewPatchField([]paddle.SubscriptionUpdateCatalogItem{
|
||||
{
|
||||
PriceID: priceID,
|
||||
Quantity: quantityGB,
|
||||
},
|
||||
}),
|
||||
ProrationBillingMode: paddle.NewPatchField(paddle.ProrationBillingModeProratedImmediately),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("paddle: failed to update subscription", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Debug("paddle: successfully applied upgrade")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) ScheduleQuantityDowngradeFromNextBillingCycle(
|
||||
logger *slog.Logger,
|
||||
providerSubscriptionID string,
|
||||
quantityGB int,
|
||||
) error {
|
||||
logger = logger.With("provider_subscription_id", providerSubscriptionID)
|
||||
logger.Debug(fmt.Sprintf("paddle: scheduling downgrade from next billing cycle: new storage %d GB", quantityGB))
|
||||
|
||||
// important: paddle requires to send all items
|
||||
// in the subscription when updating, not just the changed one
|
||||
subscription, err := s.client.GetSubscription(context.Background(), &paddle.GetSubscriptionRequest{
|
||||
SubscriptionID: providerSubscriptionID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("paddle: failed to get subscription", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
currentQuantity := subscription.Items[0].Quantity
|
||||
if currentQuantity == quantityGB {
|
||||
logger.Info("paddle: subscription already at requested quantity, skipping downgrade",
|
||||
"current_quantity_gb", currentQuantity,
|
||||
"requested_quantity_gb", quantityGB,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
if subscription.ScheduledChange != nil {
|
||||
logger.Info("paddle: subscription already has a scheduled change, skipping downgrade")
|
||||
return nil
|
||||
}
|
||||
|
||||
priceID := subscription.Items[0].Price.ID
|
||||
|
||||
// apply downgrade from next billing cycle by setting the proration billing mode to "prorate on next billing period"
|
||||
_, err = s.client.UpdateSubscription(context.Background(), &paddle.UpdateSubscriptionRequest{
|
||||
SubscriptionID: providerSubscriptionID,
|
||||
Items: paddle.NewPatchField([]paddle.SubscriptionUpdateCatalogItem{
|
||||
{
|
||||
PriceID: priceID,
|
||||
Quantity: quantityGB,
|
||||
},
|
||||
}),
|
||||
ProrationBillingMode: paddle.NewPatchField(paddle.ProrationBillingModeFullNextBillingPeriod),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("paddle: failed to update subscription for downgrade", "error", err)
|
||||
return fmt.Errorf("failed to update subscription: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("paddle: successfully scheduled downgrade from next billing cycle")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) GetSubscription(
|
||||
logger *slog.Logger,
|
||||
providerSubscriptionID string,
|
||||
) (billing_provider.ProviderSubscription, error) {
|
||||
logger = logger.With("provider_subscription_id", providerSubscriptionID)
|
||||
logger.Debug("paddle: getting subscription details")
|
||||
|
||||
subscription, err := s.client.GetSubscription(context.Background(), &paddle.GetSubscriptionRequest{
|
||||
SubscriptionID: providerSubscriptionID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("paddle: failed to get subscription", "error", err)
|
||||
return billing_provider.ProviderSubscription{}, err
|
||||
}
|
||||
|
||||
logger.Debug(
|
||||
fmt.Sprintf(
|
||||
"paddle: successfully got subscription details: status=%s, quantity=%d",
|
||||
subscription.Status,
|
||||
subscription.Items[0].Quantity,
|
||||
),
|
||||
)
|
||||
|
||||
return s.toProviderSubscription(logger, subscription)
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) CreatePortalSession(
|
||||
logger *slog.Logger,
|
||||
providerCustomerID, returnURL string,
|
||||
) (string, error) {
|
||||
logger = logger.With("provider_customer_id", providerCustomerID)
|
||||
logger.Debug("paddle: creating portal session")
|
||||
|
||||
subscriptions, err := s.client.ListSubscriptions(context.Background(), &paddle.ListSubscriptionsRequest{
|
||||
CustomerID: []string{providerCustomerID},
|
||||
Status: []string{
|
||||
string(paddle.SubscriptionStatusActive),
|
||||
string(paddle.SubscriptionStatusPastDue),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("paddle: failed to list subscriptions for portal session", "error", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
res := subscriptions.Next(context.Background())
|
||||
if !res.Ok() {
|
||||
if res.Err() != nil {
|
||||
logger.Error("paddle: failed to iterate subscriptions", "error", res.Err())
|
||||
return "", res.Err()
|
||||
}
|
||||
|
||||
logger.Error("paddle: no active subscriptions found for customer")
|
||||
return "", fmt.Errorf("no active subscriptions found for customer %s", providerCustomerID)
|
||||
}
|
||||
|
||||
subscription := res.Value()
|
||||
if subscription.ManagementURLs.UpdatePaymentMethod == nil {
|
||||
logger.Error("paddle: subscription has no management URL")
|
||||
return "", fmt.Errorf("subscription %s has no management URL", subscription.ID)
|
||||
}
|
||||
|
||||
return *subscription.ManagementURLs.UpdatePaymentMethod, nil
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) VerifyWebhookSignature(body []byte, headers map[string]string) error {
|
||||
req, _ := http.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader(body))
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
ok, err := s.webhookVerified.Verify(req)
|
||||
if err != nil || !ok {
|
||||
return fmt.Errorf("failed to verify webhook signature: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) ProcessWebhookEvent(
|
||||
logger *slog.Logger,
|
||||
requestID uuid.UUID,
|
||||
webhookDTO PaddleWebhookDTO,
|
||||
rawBody []byte,
|
||||
) error {
|
||||
webhookEvent, err := s.normalizeWebhookEvent(
|
||||
logger,
|
||||
requestID,
|
||||
webhookDTO.EventID,
|
||||
webhookDTO.EventType,
|
||||
webhookDTO.Data,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, billing_webhooks.ErrUnsupportedEventType) {
|
||||
return s.skipWebhookEvent(logger, requestID, webhookDTO, rawBody)
|
||||
}
|
||||
|
||||
logger.Error("paddle: failed to normalize webhook event", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
logArgs := []any{
|
||||
"provider_event_id", webhookEvent.ProviderEventID,
|
||||
"provider_subscription_id", webhookEvent.ProviderSubscriptionID,
|
||||
"provider_customer_id", webhookEvent.ProviderCustomerID,
|
||||
}
|
||||
if webhookEvent.DatabaseID != nil {
|
||||
logArgs = append(logArgs, "database_id", webhookEvent.DatabaseID)
|
||||
}
|
||||
|
||||
logger = logger.With(logArgs...)
|
||||
|
||||
existingRecord, err := s.webhookRepository.FindSuccessfulByProviderEventID(webhookEvent.ProviderEventID)
|
||||
if err == nil && existingRecord != nil {
|
||||
logger.Info("paddle: webhook already processed successfully, skipping",
|
||||
"existing_request_id", existingRecord.RequestID,
|
||||
)
|
||||
return billing_webhooks.ErrDuplicateWebhook
|
||||
}
|
||||
|
||||
webhookRecord := &billing_webhooks.WebhookRecord{
|
||||
RequestID: requestID,
|
||||
ProviderName: billing_provider.ProviderPaddle,
|
||||
EventType: string(webhookEvent.Type),
|
||||
ProviderEventID: webhookEvent.ProviderEventID,
|
||||
RawPayload: string(rawBody),
|
||||
}
|
||||
if err := s.webhookRepository.Insert(webhookRecord); err != nil {
|
||||
logger.Error("paddle: failed to save webhook record", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.processWebhookEvent(logger, webhookEvent); err != nil {
|
||||
logger.Error("paddle: failed to process webhook event", "error", err)
|
||||
|
||||
if markErr := s.webhookRepository.MarkError(requestID.String(), err.Error()); markErr != nil {
|
||||
logger.Error("paddle: failed to mark webhook as errored", "error", markErr)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
if markErr := s.webhookRepository.MarkProcessed(requestID.String()); markErr != nil {
|
||||
logger.Error("paddle: failed to mark webhook as processed", "error", markErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) skipWebhookEvent(
|
||||
logger *slog.Logger,
|
||||
requestID uuid.UUID,
|
||||
webhookDTO PaddleWebhookDTO,
|
||||
rawBody []byte,
|
||||
) error {
|
||||
webhookRecord := &billing_webhooks.WebhookRecord{
|
||||
RequestID: requestID,
|
||||
ProviderName: billing_provider.ProviderPaddle,
|
||||
EventType: webhookDTO.EventType,
|
||||
ProviderEventID: webhookDTO.EventID,
|
||||
RawPayload: string(rawBody),
|
||||
}
|
||||
|
||||
if err := s.webhookRepository.Insert(webhookRecord); err != nil {
|
||||
logger.Error("paddle: failed to save skipped webhook record", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.webhookRepository.MarkSkipped(requestID.String()); err != nil {
|
||||
logger.Error("paddle: failed to mark webhook as skipped", "error", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) processWebhookEvent(
|
||||
logger *slog.Logger,
|
||||
webhookEvent billing_models.WebhookEvent,
|
||||
) error {
|
||||
logger.Debug("processing webhook event")
|
||||
|
||||
// subscription.created - there is no subscription in the database yet
|
||||
if webhookEvent.Type == billing_models.WHEventSubscriptionCreated {
|
||||
return s.billingService.ActivateSubscription(logger, webhookEvent)
|
||||
}
|
||||
|
||||
// dispute - finds subscription via invoice, no provider subscription ID available
|
||||
if webhookEvent.Type == billing_models.WHEventSubscriptionDisputeCreated {
|
||||
return s.billingService.RecordDispute(logger, webhookEvent)
|
||||
}
|
||||
|
||||
// for others - search subscription first
|
||||
subscription, err := s.billingService.GetSubscriptionByProviderSubID(logger, webhookEvent.ProviderSubscriptionID)
|
||||
if err != nil {
|
||||
logger.Error("paddle: failed to find subscription for webhook event", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
logger = logger.With("subscription_id", subscription.ID, "database_id", subscription.DatabaseID)
|
||||
logger.Debug(fmt.Sprintf("found subscription in DB with ID: %s", subscription.ID))
|
||||
|
||||
switch webhookEvent.Type {
|
||||
case billing_models.WHEventSubscriptionUpdated:
|
||||
if subscription.Status == billing_models.StatusCanceled {
|
||||
return s.billingService.ReactivateSubscription(logger, subscription, webhookEvent)
|
||||
}
|
||||
|
||||
return s.billingService.SyncSubscriptionFromProvider(logger, subscription, webhookEvent)
|
||||
case billing_models.WHEventSubscriptionCanceled:
|
||||
return s.billingService.CancelSubscription(logger, subscription, webhookEvent)
|
||||
case billing_models.WHEventPaymentSucceeded:
|
||||
return s.billingService.RecordPaymentSuccess(logger, subscription, webhookEvent)
|
||||
case billing_models.WHEventSubscriptionPastDue:
|
||||
return s.billingService.RecordPaymentFailed(logger, subscription, webhookEvent)
|
||||
default:
|
||||
logger.Error(fmt.Sprintf("unhandled webhook event type: %s", string(webhookEvent.Type)))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) normalizeWebhookEvent(
|
||||
logger *slog.Logger,
|
||||
requestID uuid.UUID,
|
||||
eventID, eventType string,
|
||||
data json.RawMessage,
|
||||
) (billing_models.WebhookEvent, error) {
|
||||
webhookEvent := billing_models.WebhookEvent{
|
||||
RequestID: requestID,
|
||||
ProviderEventID: eventID,
|
||||
}
|
||||
|
||||
switch eventType {
|
||||
case "subscription.created":
|
||||
webhookEvent.Type = billing_models.WHEventSubscriptionCreated
|
||||
|
||||
var subscription paddle.Subscription
|
||||
if err := json.Unmarshal(data, &subscription); err != nil {
|
||||
logger.Error("paddle: failed to unmarshal subscription.created webhook data", "error", err)
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.ProviderSubscriptionID = subscription.ID
|
||||
webhookEvent.ProviderCustomerID = subscription.CustomerID
|
||||
webhookEvent.QuantityGB = subscription.Items[0].Quantity
|
||||
status, err := mapPaddleStatus(logger, subscription.Status)
|
||||
if err != nil {
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.Status = status
|
||||
|
||||
if subscription.CurrentBillingPeriod != nil {
|
||||
periodStart := mustParseRFC3339(logger, "period start", subscription.CurrentBillingPeriod.StartsAt)
|
||||
periodEnd := mustParseRFC3339(logger, "period end", subscription.CurrentBillingPeriod.EndsAt)
|
||||
|
||||
webhookEvent.PeriodStart = &periodStart
|
||||
webhookEvent.PeriodEnd = &periodEnd
|
||||
}
|
||||
|
||||
if subscription.CustomData == nil || subscription.CustomData["database_id"] == "" {
|
||||
logger.Error("paddle: subscription has no database_id in custom data")
|
||||
}
|
||||
|
||||
databaseIDStr, isOk := subscription.CustomData["database_id"].(string)
|
||||
if !isOk {
|
||||
logger.Error("paddle: database_id in custom data is not a string")
|
||||
return webhookEvent, fmt.Errorf("invalid database_id type in custom data")
|
||||
}
|
||||
|
||||
databaseID := uuid.MustParse(databaseIDStr)
|
||||
webhookEvent.DatabaseID = &databaseID
|
||||
|
||||
case "subscription.updated":
|
||||
var subscription paddle.Subscription
|
||||
if err := json.Unmarshal(data, &subscription); err != nil {
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.ProviderSubscriptionID = subscription.ID
|
||||
webhookEvent.ProviderCustomerID = subscription.CustomerID
|
||||
webhookEvent.QuantityGB = subscription.Items[0].Quantity
|
||||
webhookEvent.Type = billing_models.WHEventSubscriptionUpdated
|
||||
|
||||
status, err := mapPaddleStatus(logger, subscription.Status)
|
||||
if err != nil {
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.Status = status
|
||||
|
||||
if subscription.CurrentBillingPeriod != nil {
|
||||
periodStart := mustParseRFC3339(logger, "period start", subscription.CurrentBillingPeriod.StartsAt)
|
||||
periodEnd := mustParseRFC3339(logger, "period end", subscription.CurrentBillingPeriod.EndsAt)
|
||||
|
||||
webhookEvent.PeriodStart = &periodStart
|
||||
webhookEvent.PeriodEnd = &periodEnd
|
||||
}
|
||||
|
||||
if subscription.ScheduledChange != nil &&
|
||||
subscription.ScheduledChange.Action == paddle.ScheduledChangeActionCancel {
|
||||
webhookEvent.Type = billing_models.WHEventSubscriptionCanceled
|
||||
}
|
||||
|
||||
case "subscription.canceled":
|
||||
webhookEvent.Type = billing_models.WHEventSubscriptionCanceled
|
||||
|
||||
var subscription paddle.Subscription
|
||||
if err := json.Unmarshal(data, &subscription); err != nil {
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.ProviderSubscriptionID = subscription.ID
|
||||
webhookEvent.ProviderCustomerID = subscription.CustomerID
|
||||
|
||||
status, err := mapPaddleStatus(logger, subscription.Status)
|
||||
if err != nil {
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.Status = status
|
||||
|
||||
case "transaction.completed":
|
||||
webhookEvent.Type = billing_models.WHEventPaymentSucceeded
|
||||
|
||||
var transaction paddle.Transaction
|
||||
if err := json.Unmarshal(data, &transaction); err != nil {
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.ProviderInvoiceID = transaction.ID
|
||||
|
||||
if len(transaction.Items) > 0 {
|
||||
webhookEvent.QuantityGB = transaction.Items[0].Quantity
|
||||
}
|
||||
|
||||
if transaction.SubscriptionID != nil {
|
||||
webhookEvent.ProviderSubscriptionID = *transaction.SubscriptionID
|
||||
}
|
||||
|
||||
if transaction.CustomerID != nil {
|
||||
webhookEvent.ProviderCustomerID = *transaction.CustomerID
|
||||
}
|
||||
|
||||
amountCents, err := strconv.ParseInt(transaction.Details.Totals.Total, 10, 64)
|
||||
if err != nil {
|
||||
logger.Error("paddle: failed to parse transaction total", "error", err)
|
||||
} else {
|
||||
webhookEvent.AmountCents = amountCents
|
||||
}
|
||||
|
||||
if transaction.BillingPeriod != nil {
|
||||
periodStart := mustParseRFC3339(logger, "period start", transaction.BillingPeriod.StartsAt)
|
||||
periodEnd := mustParseRFC3339(logger, "period end", transaction.BillingPeriod.EndsAt)
|
||||
|
||||
webhookEvent.PeriodStart = &periodStart
|
||||
webhookEvent.PeriodEnd = &periodEnd
|
||||
}
|
||||
|
||||
case "subscription.past_due":
|
||||
webhookEvent.Type = billing_models.WHEventSubscriptionPastDue
|
||||
|
||||
var subscription paddle.Subscription
|
||||
if err := json.Unmarshal(data, &subscription); err != nil {
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.ProviderSubscriptionID = subscription.ID
|
||||
webhookEvent.ProviderCustomerID = subscription.CustomerID
|
||||
webhookEvent.QuantityGB = subscription.Items[0].Quantity
|
||||
|
||||
status, err := mapPaddleStatus(logger, subscription.Status)
|
||||
if err != nil {
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.Status = status
|
||||
|
||||
if subscription.CurrentBillingPeriod != nil {
|
||||
periodStart := mustParseRFC3339(logger, "period start", subscription.CurrentBillingPeriod.StartsAt)
|
||||
periodEnd := mustParseRFC3339(logger, "period end", subscription.CurrentBillingPeriod.EndsAt)
|
||||
|
||||
webhookEvent.PeriodStart = &periodStart
|
||||
webhookEvent.PeriodEnd = &periodEnd
|
||||
}
|
||||
|
||||
case "adjustment.created":
|
||||
webhookEvent.Type = billing_models.WHEventSubscriptionDisputeCreated
|
||||
|
||||
var adjustment struct {
|
||||
TransactionID string `json:"transaction_id"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &adjustment); err != nil {
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.ProviderInvoiceID = adjustment.TransactionID
|
||||
|
||||
default:
|
||||
logger.Debug("unsupported paddle event type, skipping", "event_type", eventType)
|
||||
return webhookEvent, billing_webhooks.ErrUnsupportedEventType
|
||||
}
|
||||
|
||||
return webhookEvent, nil
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) toProviderSubscription(
|
||||
logger *slog.Logger,
|
||||
paddleSubscription *paddle.Subscription,
|
||||
) (billing_provider.ProviderSubscription, error) {
|
||||
status, err := mapPaddleStatus(logger, paddleSubscription.Status)
|
||||
if err != nil {
|
||||
return billing_provider.ProviderSubscription{}, err
|
||||
}
|
||||
|
||||
if len(paddleSubscription.Items) == 0 {
|
||||
return billing_provider.ProviderSubscription{}, fmt.Errorf(
|
||||
"paddle subscription %s has no items",
|
||||
paddleSubscription.ID,
|
||||
)
|
||||
}
|
||||
|
||||
providerSubscription := &billing_provider.ProviderSubscription{
|
||||
ProviderSubscriptionID: paddleSubscription.ID,
|
||||
ProviderCustomerID: paddleSubscription.CustomerID,
|
||||
Status: status,
|
||||
QuantityGB: paddleSubscription.Items[0].Quantity,
|
||||
}
|
||||
|
||||
if paddleSubscription.CurrentBillingPeriod != nil {
|
||||
providerSubscription.PeriodStart = mustParseRFC3339(
|
||||
logger,
|
||||
"period start",
|
||||
paddleSubscription.CurrentBillingPeriod.StartsAt,
|
||||
)
|
||||
providerSubscription.PeriodEnd = mustParseRFC3339(
|
||||
logger,
|
||||
"period end",
|
||||
paddleSubscription.CurrentBillingPeriod.EndsAt,
|
||||
)
|
||||
}
|
||||
|
||||
return *providerSubscription, nil
|
||||
}
|
||||
|
||||
func mustParseRFC3339(logger *slog.Logger, label, value string) time.Time {
|
||||
parsed, err := time.Parse(time.RFC3339, value)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("paddle: failed to parse %s", label), "error", err)
|
||||
}
|
||||
|
||||
return parsed
|
||||
}
|
||||
|
||||
func mapPaddleStatus(logger *slog.Logger, s paddle.SubscriptionStatus) (billing_models.SubscriptionStatus, error) {
|
||||
switch s {
|
||||
case paddle.SubscriptionStatusActive:
|
||||
return billing_models.StatusActive, nil
|
||||
case paddle.SubscriptionStatusPastDue:
|
||||
return billing_models.StatusPastDue, nil
|
||||
case paddle.SubscriptionStatusCanceled:
|
||||
return billing_models.StatusCanceled, nil
|
||||
case paddle.SubscriptionStatusTrialing:
|
||||
return billing_models.StatusTrial, nil
|
||||
case paddle.SubscriptionStatusPaused:
|
||||
return billing_models.StatusCanceled, nil
|
||||
default:
|
||||
logger.Error(fmt.Sprintf("paddle: unknown subscription status: %s", string(s)))
|
||||
|
||||
return "", fmt.Errorf("paddle: unknown subscription status: %s", string(s))
|
||||
}
|
||||
}
|
||||
38
backend/internal/features/billing/provider/dto.go
Normal file
38
backend/internal/features/billing/provider/dto.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package billing_provider
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
)
|
||||
|
||||
type CreateSubscriptionRequest struct {
|
||||
ProviderCustomerID string
|
||||
DatabaseID uuid.UUID
|
||||
StorageGB int
|
||||
}
|
||||
|
||||
type ProviderSubscription struct {
|
||||
ProviderSubscriptionID string
|
||||
ProviderCustomerID string
|
||||
Status billing_models.SubscriptionStatus
|
||||
QuantityGB int
|
||||
PeriodStart time.Time
|
||||
PeriodEnd time.Time
|
||||
}
|
||||
|
||||
type CheckoutRequest struct {
|
||||
DatabaseID uuid.UUID
|
||||
Email string
|
||||
StorageGB int
|
||||
SuccessURL string
|
||||
CancelURL string
|
||||
}
|
||||
|
||||
type ProviderName string
|
||||
|
||||
const (
|
||||
ProviderPaddle ProviderName = "paddle"
|
||||
)
|
||||
21
backend/internal/features/billing/provider/provider.go
Normal file
21
backend/internal/features/billing/provider/provider.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package billing_provider
|
||||
|
||||
import "log/slog"
|
||||
|
||||
type BillingProvider interface {
|
||||
GetProviderName() ProviderName
|
||||
|
||||
UpgradeQuantityWithSurcharge(logger *slog.Logger, providerSubscriptionID string, quantityGB int) error
|
||||
|
||||
ScheduleQuantityDowngradeFromNextBillingCycle(
|
||||
logger *slog.Logger,
|
||||
providerSubscriptionID string,
|
||||
quantityGB int,
|
||||
) error
|
||||
|
||||
GetSubscription(logger *slog.Logger, providerSubscriptionID string) (ProviderSubscription, error)
|
||||
|
||||
CreateCheckoutSession(logger *slog.Logger, req CheckoutRequest) (checkoutURL string, err error)
|
||||
|
||||
CreatePortalSession(logger *slog.Logger, providerCustomerID, returnURL string) (portalURL string, err error)
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package billing_repositories
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
type InvoiceRepository struct{}
|
||||
|
||||
func (r *InvoiceRepository) Save(invoice billing_models.Invoice) error {
|
||||
if invoice.SubscriptionID == uuid.Nil {
|
||||
return errors.New("subscription id is required")
|
||||
}
|
||||
|
||||
db := storage.GetDb()
|
||||
|
||||
if invoice.ID == uuid.Nil {
|
||||
invoice.ID = uuid.New()
|
||||
return db.Create(&invoice).Error
|
||||
}
|
||||
|
||||
return db.Save(invoice).Error
|
||||
}
|
||||
|
||||
func (r *InvoiceRepository) FindByProviderInvID(providerInvoiceID string) (*billing_models.Invoice, error) {
|
||||
var invoice billing_models.Invoice
|
||||
|
||||
if err := storage.GetDb().Where("provider_invoice_id = ?", providerInvoiceID).
|
||||
First(&invoice).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &invoice, nil
|
||||
}
|
||||
|
||||
func (r *InvoiceRepository) FindByDatabaseID(
|
||||
databaseID uuid.UUID,
|
||||
limit, offset int,
|
||||
) ([]*billing_models.Invoice, error) {
|
||||
var invoices []*billing_models.Invoice
|
||||
|
||||
if err := storage.GetDb().Joins("JOIN subscriptions ON subscriptions.id = invoices.subscription_id").
|
||||
Where("subscriptions.database_id = ?", databaseID).
|
||||
Order("invoices.created_at DESC").
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&invoices).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invoices, nil
|
||||
}
|
||||
|
||||
func (r *InvoiceRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error) {
|
||||
var count int64
|
||||
|
||||
err := storage.GetDb().Model(&billing_models.Invoice{}).
|
||||
Joins("JOIN subscriptions ON subscriptions.id = invoices.subscription_id").
|
||||
Where("subscriptions.database_id = ?", databaseID).
|
||||
Count(&count).Error
|
||||
|
||||
return count, err
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package billing_repositories
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
type SubscriptionEventRepository struct{}
|
||||
|
||||
func (r *SubscriptionEventRepository) Create(event billing_models.SubscriptionEvent) error {
|
||||
if event.SubscriptionID == uuid.Nil {
|
||||
return errors.New("subscription id is required")
|
||||
}
|
||||
|
||||
event.ID = uuid.New()
|
||||
return storage.GetDb().Create(&event).Error
|
||||
}
|
||||
|
||||
func (r *SubscriptionEventRepository) FindByDatabaseID(
|
||||
databaseID uuid.UUID,
|
||||
limit, offset int,
|
||||
) ([]*billing_models.SubscriptionEvent, error) {
|
||||
var events []*billing_models.SubscriptionEvent
|
||||
|
||||
if err := storage.GetDb().Joins("JOIN subscriptions ON subscriptions.id = subscription_events.subscription_id").
|
||||
Where("subscriptions.database_id = ?", databaseID).
|
||||
Order("subscription_events.created_at DESC").
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&events).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
func (r *SubscriptionEventRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error) {
|
||||
var count int64
|
||||
|
||||
err := storage.GetDb().Model(&billing_models.SubscriptionEvent{}).
|
||||
Joins("JOIN subscriptions ON subscriptions.id = subscription_events.subscription_id").
|
||||
Where("subscriptions.database_id = ?", databaseID).
|
||||
Count(&count).Error
|
||||
|
||||
return count, err
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
package billing_repositories
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
type SubscriptionRepository struct{}
|
||||
|
||||
func (r *SubscriptionRepository) Save(sub billing_models.Subscription) error {
|
||||
db := storage.GetDb()
|
||||
|
||||
if sub.ID == uuid.Nil {
|
||||
sub.ID = uuid.New()
|
||||
return db.Create(&sub).Error
|
||||
}
|
||||
|
||||
return db.Save(&sub).Error
|
||||
}
|
||||
|
||||
func (r *SubscriptionRepository) FindByID(id uuid.UUID) (*billing_models.Subscription, error) {
|
||||
var sub billing_models.Subscription
|
||||
|
||||
if err := storage.GetDb().Where("id = ?", id).First(&sub).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
func (r *SubscriptionRepository) FindByDatabaseIDAndStatuses(
|
||||
databaseID uuid.UUID,
|
||||
stauses []billing_models.SubscriptionStatus,
|
||||
) ([]*billing_models.Subscription, error) {
|
||||
var subs []*billing_models.Subscription
|
||||
|
||||
if err := storage.GetDb().Where("database_id = ? AND status IN ?", databaseID, stauses).
|
||||
Find(&subs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return subs, nil
|
||||
}
|
||||
|
||||
func (r *SubscriptionRepository) FindLatestByDatabaseID(databaseID uuid.UUID) (*billing_models.Subscription, error) {
|
||||
var sub billing_models.Subscription
|
||||
|
||||
if err := storage.GetDb().
|
||||
Where("database_id = ?", databaseID).
|
||||
Order("created_at DESC").
|
||||
First(&sub).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
func (r *SubscriptionRepository) FindByProviderSubID(providerSubID string) (*billing_models.Subscription, error) {
|
||||
var sub billing_models.Subscription
|
||||
|
||||
if err := storage.GetDb().Where("provider_sub_id = ?", providerSubID).
|
||||
First(&sub).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
func (r *SubscriptionRepository) FindByStatuses(
|
||||
statuses []billing_models.SubscriptionStatus,
|
||||
) ([]billing_models.Subscription, error) {
|
||||
var subs []billing_models.Subscription
|
||||
|
||||
if err := storage.GetDb().Where("status IN ?", statuses).Find(&subs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return subs, nil
|
||||
}
|
||||
|
||||
func (r *SubscriptionRepository) FindCanceledWithEndedGracePeriod(
|
||||
now time.Time,
|
||||
) ([]billing_models.Subscription, error) {
|
||||
var subs []billing_models.Subscription
|
||||
|
||||
if err := storage.GetDb().
|
||||
Where("status = ? AND data_retention_grace_period_until < ?", billing_models.StatusCanceled, now).
|
||||
Find(&subs).
|
||||
Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return subs, nil
|
||||
}
|
||||
|
||||
func (r *SubscriptionRepository) FindExpiredTrials(now time.Time) ([]billing_models.Subscription, error) {
|
||||
var subs []billing_models.Subscription
|
||||
|
||||
if err := storage.GetDb().Where("status = ? AND current_period_end < ?", billing_models.StatusTrial, now).
|
||||
Find(&subs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return subs, nil
|
||||
}
|
||||
1261
backend/internal/features/billing/service.go
Normal file
1261
backend/internal/features/billing/service.go
Normal file
File diff suppressed because it is too large
Load Diff
8
backend/internal/features/billing/webhooks/errors.go
Normal file
8
backend/internal/features/billing/webhooks/errors.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package billing_webhooks
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrDuplicateWebhook = errors.New("duplicate webhook event")
|
||||
ErrUnsupportedEventType = errors.New("unsupported webhook event type")
|
||||
)
|
||||
25
backend/internal/features/billing/webhooks/model.go
Normal file
25
backend/internal/features/billing/webhooks/model.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package billing_webhooks
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
billing_provider "databasus-backend/internal/features/billing/provider"
|
||||
)
|
||||
|
||||
type WebhookRecord struct {
|
||||
RequestID uuid.UUID `gorm:"column:request_id;primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
ProviderName billing_provider.ProviderName `gorm:"column:provider_name;type:text;not null"`
|
||||
EventType string `gorm:"column:event_type;type:text;not null"`
|
||||
ProviderEventID string `gorm:"column:provider_event_id;type:text;not null;index"`
|
||||
RawPayload string `gorm:"column:raw_payload;type:text;not null"`
|
||||
ProcessedAt *time.Time `gorm:"column:processed_at"`
|
||||
IsSkipped bool `gorm:"column:is_skipped;not null;default:false"`
|
||||
Error *string `gorm:"column:error"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;not null"`
|
||||
}
|
||||
|
||||
func (WebhookRecord) TableName() string {
|
||||
return "webhook_records"
|
||||
}
|
||||
73
backend/internal/features/billing/webhooks/repository.go
Normal file
73
backend/internal/features/billing/webhooks/repository.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package billing_webhooks
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
type WebhookRepository struct{}
|
||||
|
||||
func (r *WebhookRepository) FindSuccessfulByProviderEventID(providerEventID string) (*WebhookRecord, error) {
|
||||
var record WebhookRecord
|
||||
|
||||
err := storage.GetDb().
|
||||
Where("provider_event_id = ? AND processed_at IS NOT NULL", providerEventID).
|
||||
First(&record).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (r *WebhookRepository) Insert(record *WebhookRecord) error {
|
||||
if record.ProviderEventID == "" {
|
||||
return errors.New("provider event ID is required")
|
||||
}
|
||||
|
||||
record.CreatedAt = time.Now().UTC()
|
||||
|
||||
return storage.GetDb().Create(record).Error
|
||||
}
|
||||
|
||||
func (r *WebhookRepository) MarkProcessed(requestID string) error {
|
||||
now := time.Now().UTC()
|
||||
|
||||
return storage.
|
||||
GetDb().
|
||||
Model(&WebhookRecord{}).
|
||||
Where("request_id = ?", requestID).
|
||||
Update("processed_at", now).
|
||||
Error
|
||||
}
|
||||
|
||||
func (r *WebhookRepository) MarkSkipped(requestID string) error {
|
||||
now := time.Now().UTC()
|
||||
|
||||
return storage.
|
||||
GetDb().
|
||||
Model(&WebhookRecord{}).
|
||||
Where("request_id = ?", requestID).
|
||||
Updates(map[string]any{
|
||||
"is_skipped": true,
|
||||
"processed_at": now,
|
||||
}).
|
||||
Error
|
||||
}
|
||||
|
||||
func (r *WebhookRepository) MarkError(requestID, errMsg string) error {
|
||||
return storage.
|
||||
GetDb().
|
||||
Model(&WebhookRecord{}).
|
||||
Where("request_id = ?", requestID).
|
||||
Update("error", errMsg).
|
||||
Error
|
||||
}
|
||||
@@ -1328,6 +1328,143 @@ func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateDatabase_WhenCloudAndUserIsNotReadOnly_ReturnsBadRequest(t *testing.T) {
|
||||
enableCloud(t)
|
||||
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Cloud Not ReadOnly", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
request := Database{
|
||||
Name: "Cloud Non-ReadOnly DB",
|
||||
WorkspaceID: &workspace.ID,
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: getTestPostgresConfig(),
|
||||
}
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "in cloud mode, only read-only database users are allowed")
|
||||
}
|
||||
|
||||
func Test_CreateDatabase_WhenCloudAndUserIsReadOnly_DatabaseCreated(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Cloud ReadOnly", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Temp DB for RO User", workspace.ID, owner.Token, router)
|
||||
|
||||
readOnlyUser := createReadOnlyUserViaAPI(t, router, database.ID, owner.Token)
|
||||
assert.NotEmpty(t, readOnlyUser.Username)
|
||||
assert.NotEmpty(t, readOnlyUser.Password)
|
||||
|
||||
RemoveTestDatabase(database)
|
||||
|
||||
enableCloud(t)
|
||||
|
||||
pgConfig := getTestPostgresConfig()
|
||||
pgConfig.Username = readOnlyUser.Username
|
||||
pgConfig.Password = readOnlyUser.Password
|
||||
|
||||
request := Database{
|
||||
Name: "Cloud ReadOnly DB",
|
||||
WorkspaceID: &workspace.ID,
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: pgConfig,
|
||||
}
|
||||
|
||||
var response Database
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusCreated,
|
||||
&response,
|
||||
)
|
||||
defer RemoveTestDatabase(&response)
|
||||
|
||||
assert.Equal(t, "Cloud ReadOnly DB", response.Name)
|
||||
assert.NotEqual(t, uuid.Nil, response.ID)
|
||||
}
|
||||
|
||||
func Test_CreateDatabase_WhenNotCloudAndUserIsNotReadOnly_DatabaseCreated(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Non-Cloud", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
request := Database{
|
||||
Name: "Non-Cloud DB",
|
||||
WorkspaceID: &workspace.ID,
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: getTestPostgresConfig(),
|
||||
}
|
||||
|
||||
var response Database
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusCreated,
|
||||
&response,
|
||||
)
|
||||
defer RemoveTestDatabase(&response)
|
||||
|
||||
assert.Equal(t, "Non-Cloud DB", response.Name)
|
||||
assert.NotEqual(t, uuid.Nil, response.ID)
|
||||
}
|
||||
|
||||
func enableCloud(t *testing.T) {
|
||||
t.Helper()
|
||||
config.GetEnv().IsCloud = true
|
||||
t.Cleanup(func() {
|
||||
config.GetEnv().IsCloud = false
|
||||
})
|
||||
}
|
||||
|
||||
func createReadOnlyUserViaAPI(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
databaseID uuid.UUID,
|
||||
token string,
|
||||
) *CreateReadOnlyUserResponse {
|
||||
var database Database
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/databases/%s", databaseID.String()),
|
||||
"Bearer "+token,
|
||||
http.StatusOK,
|
||||
&database,
|
||||
)
|
||||
|
||||
var response CreateReadOnlyUserResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/create-readonly-user",
|
||||
"Bearer "+token,
|
||||
database,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
return &response
|
||||
}
|
||||
|
||||
func getTestMariadbConfig() *mariadb.MariadbDatabase {
|
||||
env := config.GetEnv()
|
||||
portStr := env.TestMariadb1011Port
|
||||
|
||||
@@ -81,8 +81,8 @@ func (p *PostgresqlDatabase) Validate() error {
|
||||
p.BackupType = PostgresBackupTypePgDump
|
||||
}
|
||||
|
||||
if p.BackupType == PostgresBackupTypePgDump && config.GetEnv().IsCloud {
|
||||
return errors.New("PG_DUMP backup type is not supported in cloud mode")
|
||||
if p.BackupType != PostgresBackupTypePgDump && config.GetEnv().IsCloud {
|
||||
return errors.New("only PG_DUMP backup type is supported in cloud mode")
|
||||
}
|
||||
|
||||
if p.BackupType == PostgresBackupTypePgDump {
|
||||
|
||||
@@ -1310,6 +1310,46 @@ func Test_Validate_WhenRequiredFieldsMissing_ReturnsError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Validate_WhenCloudAndBackupTypeIsNotPgDump_ValidationFails(t *testing.T) {
|
||||
enableCloud(t)
|
||||
|
||||
model := &PostgresqlDatabase{
|
||||
Host: "example.com",
|
||||
Port: 5432,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CpuCount: 1,
|
||||
BackupType: PostgresBackupTypeWalV1,
|
||||
}
|
||||
|
||||
err := model.Validate()
|
||||
assert.EqualError(t, err, "only PG_DUMP backup type is supported in cloud mode")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenCloudAndBackupTypeIsPgDump_ValidationPasses(t *testing.T) {
|
||||
enableCloud(t)
|
||||
|
||||
model := &PostgresqlDatabase{
|
||||
Host: "example.com",
|
||||
Port: 5432,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CpuCount: 1,
|
||||
BackupType: PostgresBackupTypePgDump,
|
||||
}
|
||||
|
||||
err := model.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func enableCloud(t *testing.T) {
|
||||
t.Helper()
|
||||
config.GetEnv().IsCloud = true
|
||||
t.Cleanup(func() {
|
||||
config.GetEnv().IsCloud = false
|
||||
})
|
||||
}
|
||||
|
||||
type PostgresContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
package plans
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var databasePlanRepository = &DatabasePlanRepository{}
|
||||
|
||||
var databasePlanService = &DatabasePlanService{
|
||||
databasePlanRepository,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
|
||||
func GetDatabasePlanService() *DatabasePlanService {
|
||||
return databasePlanService
|
||||
}
|
||||
|
||||
func GetDatabasePlanRepository() *DatabasePlanRepository {
|
||||
return databasePlanRepository
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package plans
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
|
||||
type DatabasePlan struct {
|
||||
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;primaryKey;not null"`
|
||||
|
||||
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
|
||||
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
|
||||
MaxStoragePeriod period.TimePeriod `json:"maxStoragePeriod" gorm:"column:max_storage_period;type:text;not null"`
|
||||
}
|
||||
|
||||
func (p *DatabasePlan) TableName() string {
|
||||
return "database_plans"
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
package plans
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
type DatabasePlanRepository struct{}
|
||||
|
||||
func (r *DatabasePlanRepository) GetDatabasePlan(databaseID uuid.UUID) (*DatabasePlan, error) {
|
||||
var databasePlan DatabasePlan
|
||||
|
||||
if err := storage.GetDb().Where("database_id = ?", databaseID).First(&databasePlan).Error; err != nil {
|
||||
if err.Error() == "record not found" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &databasePlan, nil
|
||||
}
|
||||
|
||||
func (r *DatabasePlanRepository) CreateDatabasePlan(databasePlan *DatabasePlan) error {
|
||||
return storage.GetDb().Create(&databasePlan).Error
|
||||
}
|
||||
@@ -1,68 +0,0 @@
|
||||
package plans
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
|
||||
type DatabasePlanService struct {
|
||||
databasePlanRepository *DatabasePlanRepository
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *DatabasePlanService) GetDatabasePlan(databaseID uuid.UUID) (*DatabasePlan, error) {
|
||||
plan, err := s.databasePlanRepository.GetDatabasePlan(databaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if plan == nil {
|
||||
s.logger.Info("no database plan found, creating default plan", "databaseID", databaseID)
|
||||
|
||||
defaultPlan := s.createDefaultDatabasePlan(databaseID)
|
||||
|
||||
err := s.databasePlanRepository.CreateDatabasePlan(defaultPlan)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to create default database plan", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return defaultPlan, nil
|
||||
}
|
||||
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
func (s *DatabasePlanService) createDefaultDatabasePlan(databaseID uuid.UUID) *DatabasePlan {
|
||||
var plan DatabasePlan
|
||||
|
||||
isCloud := config.GetEnv().IsCloud
|
||||
if isCloud {
|
||||
s.logger.Info("creating default database plan for cloud", "databaseID", databaseID)
|
||||
|
||||
// for playground we set limited storages enough to test,
|
||||
// but not too expensive to provide it for Databasus
|
||||
plan = DatabasePlan{
|
||||
DatabaseID: databaseID,
|
||||
MaxBackupSizeMB: 100, // ~ 1.5GB database
|
||||
MaxBackupsTotalSizeMB: 4000, // ~ 30 daily backups + 10 manual backups
|
||||
MaxStoragePeriod: period.PeriodWeek,
|
||||
}
|
||||
} else {
|
||||
s.logger.Info("creating default database plan for self hosted", "databaseID", databaseID)
|
||||
|
||||
// by default - everything is unlimited in self hosted mode
|
||||
plan = DatabasePlan{
|
||||
DatabaseID: databaseID,
|
||||
MaxBackupSizeMB: 0,
|
||||
MaxBackupsTotalSizeMB: 0,
|
||||
MaxStoragePeriod: period.PeriodForever,
|
||||
}
|
||||
}
|
||||
|
||||
return &plan
|
||||
}
|
||||
@@ -775,7 +775,123 @@ func cleanupDatabaseWithBackup(database *databases.Database, backup *backups_cor
|
||||
}
|
||||
}
|
||||
|
||||
func Test_RestoreBackup_WhenCloudAndCpuCountMoreThanOne_ReturnsBadRequest(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
defer cleanupDatabaseWithBackup(database, backup)
|
||||
|
||||
enableCloud(t)
|
||||
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
CpuCount: 4,
|
||||
},
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "multi-thread restore is not supported in cloud mode")
|
||||
}
|
||||
|
||||
func Test_RestoreBackup_WhenCloudAndCpuCountIsOne_RestoreInitiated(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
|
||||
_, cleanup := SetupMockRestoreNode(t)
|
||||
defer cleanup()
|
||||
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
defer cleanupDatabaseWithBackup(database, backup)
|
||||
|
||||
enableCloud(t)
|
||||
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
CpuCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "restore started successfully")
|
||||
}
|
||||
|
||||
func Test_RestoreBackup_WhenNotCloudAndCpuCountMoreThanOne_RestoreInitiated(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
|
||||
_, cleanup := SetupMockRestoreNode(t)
|
||||
defer cleanup()
|
||||
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
defer cleanupDatabaseWithBackup(database, backup)
|
||||
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
CpuCount: 4,
|
||||
},
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "restore started successfully")
|
||||
}
|
||||
|
||||
func cleanupBackup(backup *backups_core.Backup) {
|
||||
repo := &backups_core.BackupRepository{}
|
||||
repo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
func enableCloud(t *testing.T) {
|
||||
t.Helper()
|
||||
env_config.GetEnv().IsCloud = true
|
||||
t.Cleanup(func() {
|
||||
env_config.GetEnv().IsCloud = false
|
||||
})
|
||||
}
|
||||
|
||||
@@ -129,11 +129,14 @@ func (s *RestoreService) RestoreBackupWithAuth(
|
||||
return err
|
||||
}
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
// in cloud mode we use only single thread mode,
|
||||
// because otherwise we will exhaust local storage
|
||||
// space (instead of streaming from S3 directly to DB)
|
||||
requestDTO.PostgresqlDatabase.CpuCount = 1
|
||||
if config.GetEnv().IsCloud && requestDTO.PostgresqlDatabase != nil &&
|
||||
requestDTO.PostgresqlDatabase.CpuCount > 1 {
|
||||
s.logger.Warn("restore rejected: multi-thread mode not supported in cloud",
|
||||
"requested_cpu_count", requestDTO.PostgresqlDatabase.CpuCount)
|
||||
|
||||
return errors.New(
|
||||
"multi-thread restore is not supported in cloud mode, only single thread (CPU=1) is allowed",
|
||||
)
|
||||
}
|
||||
|
||||
if err := s.validateVersionCompatibility(backupDatabase, requestDTO); err != nil {
|
||||
|
||||
@@ -21,13 +21,19 @@ func NewMultiHandler(
|
||||
}
|
||||
|
||||
func (h *MultiHandler) Enabled(ctx context.Context, level slog.Level) bool {
|
||||
if h.victoriaLogsWriter != nil {
|
||||
return level >= slog.LevelDebug
|
||||
}
|
||||
|
||||
return h.stdoutHandler.Enabled(ctx, level)
|
||||
}
|
||||
|
||||
func (h *MultiHandler) Handle(ctx context.Context, record slog.Record) error {
|
||||
// Send to stdout handler
|
||||
if err := h.stdoutHandler.Handle(ctx, record); err != nil {
|
||||
return err
|
||||
// Send to stdout handler (only if level is enabled for stdout)
|
||||
if h.stdoutHandler.Enabled(ctx, record.Level) {
|
||||
if err := h.stdoutHandler.Handle(ctx, record); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Send to VictoriaLogs if configured
|
||||
|
||||
102
backend/migrations/20260326130504_add_billing.sql
Normal file
102
backend/migrations/20260326130504_add_billing.sql
Normal file
@@ -0,0 +1,102 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
|
||||
CREATE TABLE subscriptions (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
database_id UUID NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
storage_gb INT NOT NULL,
|
||||
pending_storage_gb INT,
|
||||
current_period_start TIMESTAMPTZ NOT NULL,
|
||||
current_period_end TIMESTAMPTZ NOT NULL,
|
||||
canceled_at TIMESTAMPTZ,
|
||||
data_retention_grace_period_until TIMESTAMPTZ,
|
||||
provider_name TEXT,
|
||||
provider_sub_id TEXT,
|
||||
provider_customer_id TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX idx_subscriptions_database_id ON subscriptions (database_id);
|
||||
CREATE INDEX idx_subscriptions_status ON subscriptions (status);
|
||||
CREATE INDEX idx_subscriptions_provider_sub_id ON subscriptions (provider_sub_id);
|
||||
|
||||
CREATE TABLE invoices (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
subscription_id UUID NOT NULL,
|
||||
provider_invoice_id TEXT NOT NULL,
|
||||
amount_cents BIGINT NOT NULL,
|
||||
storage_gb INT NOT NULL,
|
||||
period_start TIMESTAMPTZ NOT NULL,
|
||||
period_end TIMESTAMPTZ NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
paid_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
ALTER TABLE invoices
|
||||
ADD CONSTRAINT fk_invoices_subscription_id
|
||||
FOREIGN KEY (subscription_id)
|
||||
REFERENCES subscriptions (id)
|
||||
ON DELETE CASCADE;
|
||||
|
||||
CREATE INDEX idx_invoices_subscription_id ON invoices (subscription_id);
|
||||
CREATE INDEX idx_invoices_provider_invoice_id ON invoices (provider_invoice_id);
|
||||
|
||||
CREATE TABLE subscription_events (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
subscription_id UUID NOT NULL,
|
||||
provider_event_id TEXT,
|
||||
type TEXT NOT NULL,
|
||||
old_storage_gb INT,
|
||||
new_storage_gb INT,
|
||||
old_status TEXT,
|
||||
new_status TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
ALTER TABLE subscription_events
|
||||
ADD CONSTRAINT fk_subscription_events_subscription_id
|
||||
FOREIGN KEY (subscription_id)
|
||||
REFERENCES subscriptions (id)
|
||||
ON DELETE CASCADE;
|
||||
|
||||
CREATE INDEX idx_subscription_events_subscription_id ON subscription_events (subscription_id);
|
||||
|
||||
CREATE TABLE webhook_records (
|
||||
request_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
provider_name TEXT NOT NULL,
|
||||
event_type TEXT NOT NULL,
|
||||
provider_event_id TEXT NOT NULL,
|
||||
raw_payload TEXT NOT NULL,
|
||||
processed_at TIMESTAMPTZ,
|
||||
error TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX idx_webhook_records_provider_event_id ON webhook_records (provider_event_id);
|
||||
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
|
||||
DROP INDEX IF EXISTS idx_webhook_records_provider_event_id;
|
||||
DROP TABLE IF EXISTS webhook_records;
|
||||
|
||||
DROP INDEX IF EXISTS idx_subscription_events_subscription_id;
|
||||
ALTER TABLE subscription_events DROP CONSTRAINT IF EXISTS fk_subscription_events_subscription_id;
|
||||
DROP TABLE IF EXISTS subscription_events;
|
||||
|
||||
DROP INDEX IF EXISTS idx_invoices_provider_invoice_id;
|
||||
DROP INDEX IF EXISTS idx_invoices_subscription_id;
|
||||
ALTER TABLE invoices DROP CONSTRAINT IF EXISTS fk_invoices_subscription_id;
|
||||
DROP TABLE IF EXISTS invoices;
|
||||
|
||||
DROP INDEX IF EXISTS idx_subscriptions_provider_sub_id;
|
||||
DROP INDEX IF EXISTS idx_subscriptions_status;
|
||||
DROP INDEX IF EXISTS idx_subscriptions_database_id;
|
||||
DROP TABLE IF EXISTS subscriptions;
|
||||
|
||||
-- +goose StatementEnd
|
||||
@@ -0,0 +1,39 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
|
||||
ALTER TABLE backup_configs
|
||||
DROP COLUMN IF EXISTS max_backup_size_mb,
|
||||
DROP COLUMN IF EXISTS max_backups_total_size_mb;
|
||||
|
||||
DROP INDEX IF EXISTS idx_database_plans_database_id;
|
||||
|
||||
ALTER TABLE database_plans
|
||||
DROP CONSTRAINT IF EXISTS fk_database_plans_database_id;
|
||||
|
||||
DROP TABLE IF EXISTS database_plans;
|
||||
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
|
||||
ALTER TABLE backup_configs
|
||||
ADD COLUMN max_backup_size_mb BIGINT NOT NULL DEFAULT 0,
|
||||
ADD COLUMN max_backups_total_size_mb BIGINT NOT NULL DEFAULT 0;
|
||||
|
||||
CREATE TABLE database_plans (
|
||||
database_id UUID PRIMARY KEY,
|
||||
max_backup_size_mb BIGINT NOT NULL,
|
||||
max_backups_total_size_mb BIGINT NOT NULL,
|
||||
max_storage_period TEXT NOT NULL
|
||||
);
|
||||
|
||||
ALTER TABLE database_plans
|
||||
ADD CONSTRAINT fk_database_plans_database_id
|
||||
FOREIGN KEY (database_id)
|
||||
REFERENCES databases (id)
|
||||
ON DELETE CASCADE;
|
||||
|
||||
CREATE INDEX idx_database_plans_database_id ON database_plans (database_id);
|
||||
|
||||
-- +goose StatementEnd
|
||||
@@ -0,0 +1,11 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE webhook_records
|
||||
ADD COLUMN is_skipped BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE webhook_records
|
||||
DROP COLUMN is_skipped;
|
||||
-- +goose StatementEnd
|
||||
@@ -2,5 +2,8 @@ MODE=development
|
||||
VITE_GITHUB_CLIENT_ID=
|
||||
VITE_GOOGLE_CLIENT_ID=
|
||||
VITE_IS_EMAIL_CONFIGURED=false
|
||||
VITE_CLOUDFLARE_TURNSTILE_SITE_KEY=
|
||||
VITE_IS_CLOUD=false
|
||||
VITE_CLOUDFLARE_TURNSTILE_SITE_KEY=
|
||||
VITE_CLOUD_PRICE_PER_GB=
|
||||
VITE_CLOUD_PADDLE_CLIENT_TOKEN=
|
||||
VITE_CLOUD_IS_PADDLE_SANDBOX=true
|
||||
@@ -5,6 +5,7 @@ import { Routes } from 'react-router';
|
||||
|
||||
import { useVersionCheck } from './shared/hooks/useVersionCheck';
|
||||
|
||||
import { IS_CLOUD, IS_PADDLE_SANDBOX, PADDLE_CLIENT_TOKEN } from './constants';
|
||||
import { userApi } from './entity/users';
|
||||
import { AuthPageComponent } from './pages/AuthPageComponent';
|
||||
import { OAuthCallbackPage } from './pages/OAuthCallbackPage';
|
||||
@@ -18,6 +19,18 @@ function AppContent() {
|
||||
|
||||
useVersionCheck();
|
||||
|
||||
useEffect(() => {
|
||||
if (IS_CLOUD && PADDLE_CLIENT_TOKEN) {
|
||||
Paddle.Environment.set(IS_PADDLE_SANDBOX ? 'sandbox' : 'production');
|
||||
Paddle.Initialize({
|
||||
token: PADDLE_CLIENT_TOKEN,
|
||||
eventCallback: (event) => {
|
||||
window.dispatchEvent(new CustomEvent('paddle-event', { detail: event }));
|
||||
},
|
||||
});
|
||||
}
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const isAuthorized = userApi.isAuthorized();
|
||||
setIsAuthorized(isAuthorized);
|
||||
|
||||
@@ -5,6 +5,9 @@ interface RuntimeConfig {
|
||||
IS_EMAIL_CONFIGURED?: string;
|
||||
CLOUDFLARE_TURNSTILE_SITE_KEY?: string;
|
||||
CONTAINER_ARCH?: string;
|
||||
CLOUD_PRICE_PER_GB?: string;
|
||||
CLOUD_PADDLE_CLIENT_TOKEN?: string;
|
||||
CLOUD_IS_PADDLE_SANDBOX?: string;
|
||||
}
|
||||
|
||||
declare global {
|
||||
@@ -31,6 +34,10 @@ export const APP_VERSION = (import.meta.env.VITE_APP_VERSION as string) || 'dev'
|
||||
export const IS_CLOUD =
|
||||
window.__RUNTIME_CONFIG__?.IS_CLOUD === 'true' || import.meta.env.VITE_IS_CLOUD === 'true';
|
||||
|
||||
export const CLOUD_PRICE_PER_GB = Number(
|
||||
window.__RUNTIME_CONFIG__?.CLOUD_PRICE_PER_GB || import.meta.env.VITE_CLOUD_PRICE_PER_GB || '0',
|
||||
);
|
||||
|
||||
export const GITHUB_CLIENT_ID =
|
||||
window.__RUNTIME_CONFIG__?.GITHUB_CLIENT_ID || import.meta.env.VITE_GITHUB_CLIENT_ID || '';
|
||||
|
||||
@@ -46,6 +53,15 @@ export const CLOUDFLARE_TURNSTILE_SITE_KEY =
|
||||
import.meta.env.VITE_CLOUDFLARE_TURNSTILE_SITE_KEY ||
|
||||
'';
|
||||
|
||||
export const PADDLE_CLIENT_TOKEN =
|
||||
window.__RUNTIME_CONFIG__?.CLOUD_PADDLE_CLIENT_TOKEN ||
|
||||
import.meta.env.VITE_CLOUD_PADDLE_CLIENT_TOKEN ||
|
||||
'';
|
||||
|
||||
export const IS_PADDLE_SANDBOX =
|
||||
window.__RUNTIME_CONFIG__?.CLOUD_IS_PADDLE_SANDBOX === 'true' ||
|
||||
import.meta.env.VITE_CLOUD_IS_PADDLE_SANDBOX === 'true';
|
||||
|
||||
const archMap: Record<string, string> = { amd64: 'x64', arm64: 'arm64' };
|
||||
const rawArch = window.__RUNTIME_CONFIG__?.CONTAINER_ARCH || 'unknown';
|
||||
export const CONTAINER_ARCH = archMap[rawArch] || rawArch;
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { getApplicationServer } from '../../../constants';
|
||||
import RequestOptions from '../../../shared/api/RequestOptions';
|
||||
import { apiHelper } from '../../../shared/api/apiHelper';
|
||||
import type { DatabasePlan } from '../../plan';
|
||||
import type { BackupConfig } from '../model/BackupConfig';
|
||||
import type { TransferDatabaseRequest } from '../model/TransferDatabaseRequest';
|
||||
|
||||
@@ -55,12 +54,4 @@ export const backupConfigApi = {
|
||||
requestOptions,
|
||||
);
|
||||
},
|
||||
|
||||
async getDatabasePlan(databaseId: string) {
|
||||
return apiHelper.fetchGetJson<DatabasePlan>(
|
||||
`${getApplicationServer()}/api/v1/backup-configs/database/${databaseId}/plan`,
|
||||
undefined,
|
||||
true,
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
@@ -8,4 +8,3 @@ export { BackupEncryption } from './model/BackupEncryption';
|
||||
export { PgWalBackupType } from './model/PgWalBackupType';
|
||||
export { RetentionPolicyType } from './model/RetentionPolicyType';
|
||||
export type { TransferDatabaseRequest } from './model/TransferDatabaseRequest';
|
||||
export type { DatabasePlan } from '../plan';
|
||||
|
||||
@@ -25,7 +25,4 @@ export interface BackupConfig {
|
||||
isRetryIfFailed: boolean;
|
||||
maxFailedTriesCount: number;
|
||||
encryption: BackupEncryption;
|
||||
|
||||
maxBackupSizeMb: number;
|
||||
maxBackupsTotalSizeMb: number;
|
||||
}
|
||||
|
||||
66
frontend/src/entity/billing/api/billingApi.ts
Normal file
66
frontend/src/entity/billing/api/billingApi.ts
Normal file
@@ -0,0 +1,66 @@
|
||||
import { getApplicationServer } from '../../../constants';
|
||||
import RequestOptions from '../../../shared/api/RequestOptions';
|
||||
import { apiHelper } from '../../../shared/api/apiHelper';
|
||||
import type { ChangeStorageResponse } from '../model/ChangeStorageResponse';
|
||||
import type { GetInvoicesResponse } from '../model/GetInvoicesResponse';
|
||||
import type { GetSubscriptionEventsResponse } from '../model/GetSubscriptionEventsResponse';
|
||||
import type { Subscription } from '../model/Subscription';
|
||||
|
||||
export const billingApi = {
|
||||
async createSubscription(databaseId: string, storageGb: number) {
|
||||
const requestOptions = new RequestOptions();
|
||||
requestOptions.setBody(JSON.stringify({ databaseId, storageGb }));
|
||||
|
||||
return apiHelper.fetchPostJson<{ paddleTransactionId: string }>(
|
||||
`${getApplicationServer()}/api/v1/billing/subscription`,
|
||||
requestOptions,
|
||||
);
|
||||
},
|
||||
|
||||
async changeStorage(databaseId: string, storageGb: number) {
|
||||
const requestOptions = new RequestOptions();
|
||||
requestOptions.setBody(JSON.stringify({ databaseId, storageGb }));
|
||||
|
||||
return apiHelper.fetchPostJson<ChangeStorageResponse>(
|
||||
`${getApplicationServer()}/api/v1/billing/subscription/change-storage`,
|
||||
requestOptions,
|
||||
);
|
||||
},
|
||||
|
||||
async getPortalSession(subscriptionId: string) {
|
||||
return apiHelper.fetchPostJson<{ url: string }>(
|
||||
`${getApplicationServer()}/api/v1/billing/subscription/portal/${subscriptionId}`,
|
||||
new RequestOptions(),
|
||||
);
|
||||
},
|
||||
|
||||
async getSubscriptionEvents(subscriptionId: string, limit?: number, offset?: number) {
|
||||
const params = new URLSearchParams();
|
||||
if (limit !== undefined) params.append('limit', limit.toString());
|
||||
if (offset !== undefined) params.append('offset', offset.toString());
|
||||
|
||||
const query = params.toString();
|
||||
const url = `${getApplicationServer()}/api/v1/billing/subscription/events/${subscriptionId}${query ? `?${query}` : ''}`;
|
||||
|
||||
return apiHelper.fetchGetJson<GetSubscriptionEventsResponse>(url, undefined, true);
|
||||
},
|
||||
|
||||
async getInvoices(subscriptionId: string, limit?: number, offset?: number) {
|
||||
const params = new URLSearchParams();
|
||||
if (limit !== undefined) params.append('limit', limit.toString());
|
||||
if (offset !== undefined) params.append('offset', offset.toString());
|
||||
|
||||
const query = params.toString();
|
||||
const url = `${getApplicationServer()}/api/v1/billing/subscription/invoices/${subscriptionId}${query ? `?${query}` : ''}`;
|
||||
|
||||
return apiHelper.fetchGetJson<GetInvoicesResponse>(url, undefined, true);
|
||||
},
|
||||
|
||||
async getSubscription(databaseId: string) {
|
||||
return apiHelper.fetchGetJson<Subscription>(
|
||||
`${getApplicationServer()}/api/v1/billing/subscription/${databaseId}`,
|
||||
undefined,
|
||||
true,
|
||||
);
|
||||
},
|
||||
};
|
||||
11
frontend/src/entity/billing/index.ts
Normal file
11
frontend/src/entity/billing/index.ts
Normal file
@@ -0,0 +1,11 @@
|
||||
export { billingApi } from './api/billingApi';
|
||||
export { SubscriptionStatus } from './model/SubscriptionStatus';
|
||||
export type { Subscription } from './model/Subscription';
|
||||
export { InvoiceStatus } from './model/InvoiceStatus';
|
||||
export type { Invoice } from './model/Invoice';
|
||||
export { SubscriptionEventType } from './model/SubscriptionEventType';
|
||||
export type { SubscriptionEvent } from './model/SubscriptionEvent';
|
||||
export { ChangeStorageApplyMode } from './model/ChangeStorageApplyMode';
|
||||
export type { ChangeStorageResponse } from './model/ChangeStorageResponse';
|
||||
export type { GetSubscriptionEventsResponse } from './model/GetSubscriptionEventsResponse';
|
||||
export type { GetInvoicesResponse } from './model/GetInvoicesResponse';
|
||||
@@ -0,0 +1,4 @@
|
||||
export enum ChangeStorageApplyMode {
|
||||
Immediate = 'immediate',
|
||||
NextCycle = 'next_cycle',
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
import { ChangeStorageApplyMode } from './ChangeStorageApplyMode';
|
||||
|
||||
export interface ChangeStorageResponse {
|
||||
applyMode: ChangeStorageApplyMode;
|
||||
currentGb: number;
|
||||
pendingGb?: number;
|
||||
}
|
||||
8
frontend/src/entity/billing/model/GetInvoicesResponse.ts
Normal file
8
frontend/src/entity/billing/model/GetInvoicesResponse.ts
Normal file
@@ -0,0 +1,8 @@
|
||||
import type { Invoice } from './Invoice';
|
||||
|
||||
export interface GetInvoicesResponse {
|
||||
invoices: Invoice[];
|
||||
total: number;
|
||||
limit: number;
|
||||
offset: number;
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
import type { SubscriptionEvent } from './SubscriptionEvent';
|
||||
|
||||
export interface GetSubscriptionEventsResponse {
|
||||
events: SubscriptionEvent[];
|
||||
total: number;
|
||||
limit: number;
|
||||
offset: number;
|
||||
}
|
||||
14
frontend/src/entity/billing/model/Invoice.ts
Normal file
14
frontend/src/entity/billing/model/Invoice.ts
Normal file
@@ -0,0 +1,14 @@
|
||||
import { InvoiceStatus } from './InvoiceStatus';
|
||||
|
||||
export interface Invoice {
|
||||
id: string;
|
||||
subscriptionId: string;
|
||||
providerInvoiceId: string;
|
||||
amountCents: number;
|
||||
storageGb: number;
|
||||
periodStart: string;
|
||||
periodEnd: string;
|
||||
status: InvoiceStatus;
|
||||
paidAt?: string;
|
||||
createdAt: string;
|
||||
}
|
||||
7
frontend/src/entity/billing/model/InvoiceStatus.ts
Normal file
7
frontend/src/entity/billing/model/InvoiceStatus.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
export enum InvoiceStatus {
|
||||
Pending = 'pending',
|
||||
Paid = 'paid',
|
||||
Failed = 'failed',
|
||||
Refunded = 'refunded',
|
||||
Disputed = 'disputed',
|
||||
}
|
||||
18
frontend/src/entity/billing/model/Subscription.ts
Normal file
18
frontend/src/entity/billing/model/Subscription.ts
Normal file
@@ -0,0 +1,18 @@
|
||||
import { SubscriptionStatus } from './SubscriptionStatus';
|
||||
|
||||
export interface Subscription {
|
||||
id: string;
|
||||
databaseId: string;
|
||||
status: SubscriptionStatus;
|
||||
storageGb: number;
|
||||
pendingStorageGb?: number;
|
||||
currentPeriodStart: string;
|
||||
currentPeriodEnd: string;
|
||||
canceledAt?: string;
|
||||
dataRetentionGracePeriodUntil?: string;
|
||||
providerName?: string;
|
||||
providerSubId?: string;
|
||||
providerCustomerId?: string;
|
||||
createdAt: string;
|
||||
updatedAt: string;
|
||||
}
|
||||
14
frontend/src/entity/billing/model/SubscriptionEvent.ts
Normal file
14
frontend/src/entity/billing/model/SubscriptionEvent.ts
Normal file
@@ -0,0 +1,14 @@
|
||||
import { SubscriptionEventType } from './SubscriptionEventType';
|
||||
import { SubscriptionStatus } from './SubscriptionStatus';
|
||||
|
||||
export interface SubscriptionEvent {
|
||||
id: string;
|
||||
subscriptionId: string;
|
||||
providerEventId?: string;
|
||||
type: SubscriptionEventType;
|
||||
oldStorageGb?: number;
|
||||
newStorageGb?: number;
|
||||
oldStatus?: SubscriptionStatus;
|
||||
newStatus?: SubscriptionStatus;
|
||||
createdAt: string;
|
||||
}
|
||||
13
frontend/src/entity/billing/model/SubscriptionEventType.ts
Normal file
13
frontend/src/entity/billing/model/SubscriptionEventType.ts
Normal file
@@ -0,0 +1,13 @@
|
||||
export enum SubscriptionEventType {
|
||||
Created = 'subscription.created',
|
||||
Upgraded = 'subscription.upgraded',
|
||||
Downgraded = 'subscription.downgraded',
|
||||
NewBillingCycleStarted = 'subscription.new_billing_cycle_started',
|
||||
Canceled = 'subscription.canceled',
|
||||
Reactivated = 'subscription.reactivated',
|
||||
Expired = 'subscription.expired',
|
||||
PastDue = 'subscription.past_due',
|
||||
RecoveredFromPastDue = 'subscription.recovered_from_past_due',
|
||||
Refund = 'payment.refund',
|
||||
Dispute = 'payment.dispute',
|
||||
}
|
||||
7
frontend/src/entity/billing/model/SubscriptionStatus.ts
Normal file
7
frontend/src/entity/billing/model/SubscriptionStatus.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
export enum SubscriptionStatus {
|
||||
Trial = 'trial',
|
||||
Active = 'active',
|
||||
PastDue = 'past_due',
|
||||
Canceled = 'canceled',
|
||||
Expired = 'expired',
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
export type { DatabasePlan } from './model/DatabasePlan';
|
||||
@@ -1,8 +0,0 @@
|
||||
import type { Period } from '../../databases/model/Period';
|
||||
|
||||
export interface DatabasePlan {
|
||||
databaseId: string;
|
||||
maxBackupSizeMb: number;
|
||||
maxBackupsTotalSizeMb: number;
|
||||
maxStoragePeriod: Period;
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
import { Button } from 'antd';
|
||||
import dayjs from 'dayjs';
|
||||
import { useEffect, useState } from 'react';
|
||||
|
||||
import { type Subscription, SubscriptionStatus, billingApi } from '../../../entity/billing';
|
||||
import { getUserShortTimeFormat } from '../../../shared/time';
|
||||
import { PurchaseComponent } from '../../billing';
|
||||
|
||||
interface Props {
|
||||
databaseId: string;
|
||||
isCanManageDBs: boolean;
|
||||
onNavigateToBilling?: () => void;
|
||||
}
|
||||
|
||||
export const BackupsBillingBannerComponent = ({
|
||||
databaseId,
|
||||
isCanManageDBs,
|
||||
onNavigateToBilling,
|
||||
}: Props) => {
|
||||
const [subscription, setSubscription] = useState<Subscription | null>(null);
|
||||
const [isPurchaseModalOpen, setIsPurchaseModalOpen] = useState(false);
|
||||
|
||||
const loadSubscription = async () => {
|
||||
try {
|
||||
const sub = await billingApi.getSubscription(databaseId);
|
||||
setSubscription(sub);
|
||||
} catch {
|
||||
setSubscription(null);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
loadSubscription();
|
||||
}, [databaseId]);
|
||||
|
||||
if (
|
||||
!subscription ||
|
||||
(subscription.status !== SubscriptionStatus.Trial &&
|
||||
subscription.status !== SubscriptionStatus.Canceled &&
|
||||
subscription.status !== SubscriptionStatus.Expired)
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<div
|
||||
className={`mt-3 rounded-lg px-4 py-3 text-sm ${
|
||||
subscription.status === SubscriptionStatus.Canceled ||
|
||||
subscription.status === SubscriptionStatus.Expired
|
||||
? 'border border-red-600/30 bg-red-900/20'
|
||||
: 'border border-yellow-600/30 bg-yellow-900/20'
|
||||
}`}
|
||||
>
|
||||
<p
|
||||
className={
|
||||
subscription.status === SubscriptionStatus.Canceled ||
|
||||
subscription.status === SubscriptionStatus.Expired
|
||||
? 'text-red-400'
|
||||
: 'text-yellow-400'
|
||||
}
|
||||
>
|
||||
{subscription.status === SubscriptionStatus.Trial && (
|
||||
<>
|
||||
You are on a free trial. Your trial ends on{' '}
|
||||
<span className="font-medium">
|
||||
{dayjs
|
||||
.utc(subscription.currentPeriodEnd)
|
||||
.local()
|
||||
.format(getUserShortTimeFormat().format)}
|
||||
</span>{' '}
|
||||
({dayjs.utc(subscription.currentPeriodEnd).local().fromNow()}). After that, backups
|
||||
will be removed.
|
||||
</>
|
||||
)}
|
||||
|
||||
{subscription.status === SubscriptionStatus.Canceled && (
|
||||
<>
|
||||
Your subscription has been canceled.{' '}
|
||||
{subscription.dataRetentionGracePeriodUntil ? (
|
||||
<>
|
||||
Backups will be removed on{' '}
|
||||
<span className="font-medium">
|
||||
{dayjs
|
||||
.utc(subscription.dataRetentionGracePeriodUntil)
|
||||
.local()
|
||||
.format(getUserShortTimeFormat().format)}
|
||||
</span>{' '}
|
||||
({dayjs.utc(subscription.dataRetentionGracePeriodUntil).local().fromNow()}).
|
||||
</>
|
||||
) : (
|
||||
<> Backups will be removed after the grace period.</>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
{subscription.status === SubscriptionStatus.Expired && (
|
||||
<>Your subscription has expired.</>
|
||||
)}
|
||||
</p>
|
||||
|
||||
{isCanManageDBs &&
|
||||
subscription.status === SubscriptionStatus.Canceled &&
|
||||
onNavigateToBilling && (
|
||||
<Button type="primary" size="small" className="mt-2" onClick={onNavigateToBilling}>
|
||||
Go to Billing
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{isCanManageDBs && subscription.status !== SubscriptionStatus.Canceled && (
|
||||
<Button
|
||||
type="primary"
|
||||
size="small"
|
||||
className="mt-2"
|
||||
onClick={() => setIsPurchaseModalOpen(true)}
|
||||
>
|
||||
Purchase storage
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{isPurchaseModalOpen && (
|
||||
<PurchaseComponent
|
||||
databaseId={databaseId}
|
||||
onSubscriptionChanged={() => loadSubscription()}
|
||||
onClose={() => setIsPurchaseModalOpen(false)}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
||||
@@ -14,6 +14,7 @@ import type { ColumnsType } from 'antd/es/table';
|
||||
import dayjs from 'dayjs';
|
||||
import { useEffect, useRef, useState } from 'react';
|
||||
|
||||
import { IS_CLOUD } from '../../../constants';
|
||||
import {
|
||||
type Backup,
|
||||
type BackupConfig,
|
||||
@@ -28,6 +29,7 @@ import { getUserTimeFormat } from '../../../shared/time';
|
||||
import { ConfirmationComponent } from '../../../shared/ui';
|
||||
import { RestoresComponent } from '../../restores';
|
||||
import { AgentRestoreComponent } from './AgentRestoreComponent';
|
||||
import { BackupsBillingBannerComponent } from './BackupsBillingBannerComponent';
|
||||
|
||||
const BACKUPS_PAGE_SIZE = 50;
|
||||
|
||||
@@ -36,6 +38,7 @@ interface Props {
|
||||
isCanManageDBs: boolean;
|
||||
isDirectlyUnderTab?: boolean;
|
||||
scrollContainerRef?: React.RefObject<HTMLDivElement | null>;
|
||||
onNavigateToBilling?: () => void;
|
||||
}
|
||||
|
||||
export const BackupsComponent = ({
|
||||
@@ -43,6 +46,7 @@ export const BackupsComponent = ({
|
||||
isCanManageDBs,
|
||||
isDirectlyUnderTab,
|
||||
scrollContainerRef,
|
||||
onNavigateToBilling,
|
||||
}: Props) => {
|
||||
const [isBackupsLoading, setIsBackupsLoading] = useState(false);
|
||||
const [backups, setBackups] = useState<Backup[]>([]);
|
||||
@@ -510,6 +514,14 @@ export const BackupsComponent = ({
|
||||
>
|
||||
<h2 className="text-lg font-bold md:text-xl dark:text-white">Backups</h2>
|
||||
|
||||
{IS_CLOUD && (
|
||||
<BackupsBillingBannerComponent
|
||||
databaseId={database.id}
|
||||
isCanManageDBs={isCanManageDBs}
|
||||
onNavigateToBilling={onNavigateToBilling}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isBackupConfigLoading && !backupConfig?.isBackupsEnabled && (
|
||||
<div className="text-sm text-red-600">
|
||||
Scheduled backups are disabled (you can enable it back in the backup configuration)
|
||||
|
||||
@@ -19,7 +19,6 @@ import { IS_CLOUD } from '../../../constants';
|
||||
import {
|
||||
type BackupConfig,
|
||||
BackupEncryption,
|
||||
type DatabasePlan,
|
||||
RetentionPolicyType,
|
||||
backupConfigApi,
|
||||
} from '../../../entity/backups';
|
||||
@@ -97,13 +96,9 @@ export const EditBackupConfigComponent = ({
|
||||
|
||||
const [isShowWarn, setIsShowWarn] = useState(false);
|
||||
|
||||
const [databasePlan, setDatabasePlan] = useState<DatabasePlan>();
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
|
||||
const hasAdvancedValues =
|
||||
!!backupConfig?.isRetryIfFailed ||
|
||||
(backupConfig?.maxBackupSizeMb ?? 0) > 0 ||
|
||||
(backupConfig?.maxBackupsTotalSizeMb ?? 0) > 0;
|
||||
const hasAdvancedValues = !!backupConfig?.isRetryIfFailed;
|
||||
const [isShowAdvanced, setShowAdvanced] = useState(hasAdvancedValues);
|
||||
const [isShowGfsHint, setShowGfsHint] = useState(false);
|
||||
|
||||
@@ -114,65 +109,6 @@ export const EditBackupConfigComponent = ({
|
||||
|
||||
const dateTimeFormat = useMemo(() => getUserTimeFormat(), []);
|
||||
|
||||
const createDefaultPlan = (databaseId: string, isCloud: boolean): DatabasePlan => {
|
||||
if (isCloud) {
|
||||
return {
|
||||
databaseId,
|
||||
maxBackupSizeMb: 100,
|
||||
maxBackupsTotalSizeMb: 4000,
|
||||
maxStoragePeriod: Period.WEEK,
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
databaseId,
|
||||
maxBackupSizeMb: 0,
|
||||
maxBackupsTotalSizeMb: 0,
|
||||
maxStoragePeriod: Period.FOREVER,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
const isPeriodAllowed = (period: Period, maxPeriod: Period): boolean => {
|
||||
const periodOrder = [
|
||||
Period.DAY,
|
||||
Period.WEEK,
|
||||
Period.MONTH,
|
||||
Period.THREE_MONTH,
|
||||
Period.SIX_MONTH,
|
||||
Period.YEAR,
|
||||
Period.TWO_YEARS,
|
||||
Period.THREE_YEARS,
|
||||
Period.FOUR_YEARS,
|
||||
Period.FIVE_YEARS,
|
||||
Period.FOREVER,
|
||||
];
|
||||
const periodIndex = periodOrder.indexOf(period);
|
||||
const maxIndex = periodOrder.indexOf(maxPeriod);
|
||||
return periodIndex <= maxIndex;
|
||||
};
|
||||
|
||||
const availablePeriods = useMemo(() => {
|
||||
const allPeriods = [
|
||||
{ label: '1 day', value: Period.DAY },
|
||||
{ label: '1 week', value: Period.WEEK },
|
||||
{ label: '1 month', value: Period.MONTH },
|
||||
{ label: '3 months', value: Period.THREE_MONTH },
|
||||
{ label: '6 months', value: Period.SIX_MONTH },
|
||||
{ label: '1 year', value: Period.YEAR },
|
||||
{ label: '2 years', value: Period.TWO_YEARS },
|
||||
{ label: '3 years', value: Period.THREE_YEARS },
|
||||
{ label: '4 years', value: Period.FOUR_YEARS },
|
||||
{ label: '5 years', value: Period.FIVE_YEARS },
|
||||
{ label: 'Forever', value: Period.FOREVER },
|
||||
];
|
||||
|
||||
if (!databasePlan) {
|
||||
return allPeriods;
|
||||
}
|
||||
|
||||
return allPeriods.filter((p) => isPeriodAllowed(p.value, databasePlan.maxStoragePeriod));
|
||||
}, [databasePlan]);
|
||||
|
||||
const updateBackupConfig = (patch: Partial<BackupConfig>) => {
|
||||
setBackupConfig((prev) => (prev ? { ...prev, ...patch } : prev));
|
||||
setIsUnsaved(true);
|
||||
@@ -237,13 +173,7 @@ export const EditBackupConfigComponent = ({
|
||||
setBackupConfig(config);
|
||||
setIsUnsaved(false);
|
||||
setIsSaving(false);
|
||||
|
||||
const plan = await backupConfigApi.getDatabasePlan(database.id);
|
||||
setDatabasePlan(plan);
|
||||
} else {
|
||||
const plan = createDefaultPlan('', IS_CLOUD);
|
||||
setDatabasePlan(plan);
|
||||
|
||||
setBackupConfig({
|
||||
databaseId: database.id,
|
||||
isBackupsEnabled: true,
|
||||
@@ -256,11 +186,7 @@ export const EditBackupConfigComponent = ({
|
||||
retentionPolicyType: IS_CLOUD
|
||||
? RetentionPolicyType.GFS
|
||||
: RetentionPolicyType.TimePeriod,
|
||||
retentionTimePeriod: IS_CLOUD
|
||||
? plan.maxStoragePeriod === Period.FOREVER
|
||||
? Period.THREE_MONTH
|
||||
: plan.maxStoragePeriod
|
||||
: Period.THREE_MONTH,
|
||||
retentionTimePeriod: Period.THREE_MONTH,
|
||||
retentionCount: 100,
|
||||
retentionGfsHours: 24,
|
||||
retentionGfsDays: 7,
|
||||
@@ -271,8 +197,6 @@ export const EditBackupConfigComponent = ({
|
||||
isRetryIfFailed: true,
|
||||
maxFailedTriesCount: 3,
|
||||
encryption: BackupEncryption.ENCRYPTED,
|
||||
maxBackupSizeMb: plan.maxBackupSizeMb,
|
||||
maxBackupsTotalSizeMb: plan.maxBackupsTotalSizeMb,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -604,7 +528,19 @@ export const EditBackupConfigComponent = ({
|
||||
onChange={(v) => updateBackupConfig({ retentionTimePeriod: v })}
|
||||
size="small"
|
||||
className="w-[200px]"
|
||||
options={availablePeriods}
|
||||
options={[
|
||||
{ label: '1 day', value: Period.DAY },
|
||||
{ label: '1 week', value: Period.WEEK },
|
||||
{ label: '1 month', value: Period.MONTH },
|
||||
{ label: '3 months', value: Period.THREE_MONTH },
|
||||
{ label: '6 months', value: Period.SIX_MONTH },
|
||||
{ label: '1 year', value: Period.YEAR },
|
||||
{ label: '2 years', value: Period.TWO_YEARS },
|
||||
{ label: '3 years', value: Period.THREE_YEARS },
|
||||
{ label: '4 years', value: Period.FOUR_YEARS },
|
||||
{ label: '5 years', value: Period.FIVE_YEARS },
|
||||
{ label: 'Forever', value: Period.FOREVER },
|
||||
]}
|
||||
/>
|
||||
|
||||
<Tooltip
|
||||
@@ -829,121 +765,6 @@ export const EditBackupConfigComponent = ({
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="mt-5 mb-1 flex w-full flex-col items-start sm:flex-row sm:items-center">
|
||||
<div className="mb-1 min-w-[150px] sm:mb-0">Max backup size limit</div>
|
||||
<div className="flex items-center">
|
||||
<Switch
|
||||
size="small"
|
||||
checked={backupConfig.maxBackupSizeMb > 0}
|
||||
disabled={IS_CLOUD}
|
||||
onChange={(checked) => {
|
||||
updateBackupConfig({
|
||||
maxBackupSizeMb: checked ? backupConfig.maxBackupSizeMb || 1000 : 0,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
|
||||
<Tooltip
|
||||
className="cursor-pointer"
|
||||
title="Limits the size of each individual backup. Note that backups are typically 15× smaller than the database size. For example, a 100 MB backup represents approximately 1.5 GB database."
|
||||
>
|
||||
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{backupConfig.maxBackupSizeMb > 0 && (
|
||||
<div className="mb-5 flex w-full flex-col items-start sm:flex-row sm:items-center">
|
||||
<div className="mb-1 min-w-[150px] sm:mb-0">Max file size (MB)</div>
|
||||
|
||||
<InputNumber
|
||||
min={1}
|
||||
max={
|
||||
databasePlan?.maxBackupSizeMb && databasePlan.maxBackupSizeMb > 0
|
||||
? databasePlan.maxBackupSizeMb
|
||||
: undefined
|
||||
}
|
||||
value={backupConfig.maxBackupSizeMb}
|
||||
onChange={(value) => {
|
||||
const newValue = value || 1;
|
||||
if (databasePlan?.maxBackupSizeMb && databasePlan.maxBackupSizeMb > 0) {
|
||||
updateBackupConfig({
|
||||
maxBackupSizeMb: Math.min(newValue, databasePlan.maxBackupSizeMb),
|
||||
});
|
||||
} else {
|
||||
updateBackupConfig({ maxBackupSizeMb: newValue });
|
||||
}
|
||||
}}
|
||||
size="small"
|
||||
className="w-full max-w-[75px] grow"
|
||||
/>
|
||||
|
||||
<div className="ml-2 text-xs text-gray-600 dark:text-gray-400">
|
||||
~{((backupConfig.maxBackupSizeMb / 1024) * 15).toFixed(2)} GB DB size
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="mb-1 flex w-full flex-col items-start sm:flex-row sm:items-center">
|
||||
<div className="mb-1 min-w-[150px] sm:mb-0">Limit total backups size</div>
|
||||
<div className="flex items-center">
|
||||
<Switch
|
||||
size="small"
|
||||
checked={backupConfig.maxBackupsTotalSizeMb > 0}
|
||||
disabled={IS_CLOUD}
|
||||
onChange={(checked) => {
|
||||
updateBackupConfig({
|
||||
maxBackupsTotalSizeMb: checked
|
||||
? backupConfig.maxBackupsTotalSizeMb || 1_000_000
|
||||
: 0,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
|
||||
<Tooltip
|
||||
className="cursor-pointer"
|
||||
title="Limits the total size of all backups in storage (like S3, local, etc.). Once this limit is exceeded, the oldest backups are automatically removed until the total size is within the limit again."
|
||||
>
|
||||
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{backupConfig.maxBackupsTotalSizeMb > 0 && (
|
||||
<div className="mb-1 flex w-full flex-col items-start sm:flex-row sm:items-center">
|
||||
<div className="mb-1 min-w-[150px] sm:mb-0">Backups files size (MB)</div>
|
||||
<InputNumber
|
||||
min={1}
|
||||
max={
|
||||
databasePlan?.maxBackupsTotalSizeMb && databasePlan.maxBackupsTotalSizeMb > 0
|
||||
? databasePlan.maxBackupsTotalSizeMb
|
||||
: undefined
|
||||
}
|
||||
value={backupConfig.maxBackupsTotalSizeMb}
|
||||
onChange={(value) => {
|
||||
const newValue = value || 1;
|
||||
if (
|
||||
databasePlan?.maxBackupsTotalSizeMb &&
|
||||
databasePlan.maxBackupsTotalSizeMb > 0
|
||||
) {
|
||||
updateBackupConfig({
|
||||
maxBackupsTotalSizeMb: Math.min(newValue, databasePlan.maxBackupsTotalSizeMb),
|
||||
});
|
||||
} else {
|
||||
updateBackupConfig({ maxBackupsTotalSizeMb: newValue });
|
||||
}
|
||||
}}
|
||||
size="small"
|
||||
className="w-full max-w-[75px] grow"
|
||||
/>
|
||||
|
||||
<div className="ml-2 text-xs text-gray-600 dark:text-gray-400">
|
||||
{(backupConfig.maxBackupsTotalSizeMb / 1024).toFixed(2)} GB (~
|
||||
{backupConfig.maxBackupsTotalSizeMb / backupConfig.maxBackupSizeMb} backups)
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
|
||||
199
frontend/src/features/billing/hooks/usePurchaseFlow.ts
Normal file
199
frontend/src/features/billing/hooks/usePurchaseFlow.ts
Normal file
@@ -0,0 +1,199 @@
|
||||
import { App } from 'antd';
|
||||
import { useEffect, useRef, useState } from 'react';
|
||||
|
||||
import { ChangeStorageApplyMode, SubscriptionStatus, billingApi } from '../../../entity/billing';
|
||||
import type { Subscription } from '../../../entity/billing';
|
||||
import { POLL_INTERVAL_MS, POLL_TIMEOUT_MS, findSliderPosForGb } from '../models/purchaseUtils';
|
||||
|
||||
interface UsePurchaseFlowParams {
|
||||
databaseId: string;
|
||||
onSubscriptionChanged: () => void;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export function usePurchaseFlow({
|
||||
databaseId,
|
||||
onSubscriptionChanged,
|
||||
onClose,
|
||||
}: UsePurchaseFlowParams) {
|
||||
const { message } = App.useApp();
|
||||
|
||||
const [subscription, setSubscription] = useState<Subscription | null>(null);
|
||||
const [isLoadingSubscription, setIsLoadingSubscription] = useState(true);
|
||||
const [loadError, setLoadError] = useState<string | null>(null);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [isCheckoutOpen, setIsCheckoutOpen] = useState(false);
|
||||
const [isWaitingForPayment, setIsWaitingForPayment] = useState(false);
|
||||
const [isPaymentConfirmed, setIsPaymentConfirmed] = useState(false);
|
||||
const [confirmedStorageGb, setConfirmedStorageGb] = useState<number | undefined>();
|
||||
const [isPaymentTimedOut, setIsPaymentTimedOut] = useState(false);
|
||||
const [isWaitingForUpgrade, setIsWaitingForUpgrade] = useState(false);
|
||||
const [isUpgradeTimedOut, setIsUpgradeTimedOut] = useState(false);
|
||||
const [initialSliderPos, setInitialSliderPos] = useState(0);
|
||||
|
||||
const pollingRef = useRef<ReturnType<typeof setInterval> | null>(null);
|
||||
const timeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
|
||||
const stopPolling = () => {
|
||||
if (pollingRef.current) {
|
||||
clearInterval(pollingRef.current);
|
||||
pollingRef.current = null;
|
||||
}
|
||||
if (timeoutRef.current) {
|
||||
clearTimeout(timeoutRef.current);
|
||||
timeoutRef.current = null;
|
||||
}
|
||||
};
|
||||
|
||||
const loadSubscription = async () => {
|
||||
setIsLoadingSubscription(true);
|
||||
setLoadError(null);
|
||||
|
||||
try {
|
||||
const sub = await billingApi.getSubscription(databaseId);
|
||||
setSubscription(sub);
|
||||
setInitialSliderPos(findSliderPosForGb(sub.storageGb));
|
||||
} catch {
|
||||
setLoadError('Failed to load subscription');
|
||||
} finally {
|
||||
setIsLoadingSubscription(false);
|
||||
}
|
||||
};
|
||||
|
||||
const pollForPaymentConfirmation = () => {
|
||||
setIsWaitingForPayment(true);
|
||||
setIsPaymentTimedOut(false);
|
||||
|
||||
pollingRef.current = setInterval(async () => {
|
||||
try {
|
||||
const sub = await billingApi.getSubscription(databaseId);
|
||||
|
||||
if (
|
||||
sub.status !== SubscriptionStatus.Trial &&
|
||||
sub.status !== SubscriptionStatus.Expired &&
|
||||
sub.status !== SubscriptionStatus.Canceled
|
||||
) {
|
||||
stopPolling();
|
||||
setIsWaitingForPayment(false);
|
||||
setIsPaymentConfirmed(true);
|
||||
setConfirmedStorageGb(sub.storageGb);
|
||||
onSubscriptionChanged();
|
||||
}
|
||||
} catch {
|
||||
// ignore polling errors, keep trying
|
||||
}
|
||||
}, POLL_INTERVAL_MS);
|
||||
|
||||
timeoutRef.current = setTimeout(() => {
|
||||
stopPolling();
|
||||
setIsWaitingForPayment(false);
|
||||
setIsPaymentTimedOut(true);
|
||||
}, POLL_TIMEOUT_MS);
|
||||
};
|
||||
|
||||
const pollForUpgradeConfirmation = (targetStorageGb: number) => {
|
||||
setIsWaitingForUpgrade(true);
|
||||
setIsUpgradeTimedOut(false);
|
||||
|
||||
pollingRef.current = setInterval(async () => {
|
||||
try {
|
||||
const sub = await billingApi.getSubscription(databaseId);
|
||||
|
||||
if (sub.storageGb === targetStorageGb && sub.pendingStorageGb === undefined) {
|
||||
stopPolling();
|
||||
setIsWaitingForUpgrade(false);
|
||||
onSubscriptionChanged();
|
||||
onClose();
|
||||
}
|
||||
} catch {
|
||||
// ignore polling errors, keep trying
|
||||
}
|
||||
}, POLL_INTERVAL_MS);
|
||||
|
||||
timeoutRef.current = setTimeout(() => {
|
||||
stopPolling();
|
||||
setIsWaitingForUpgrade(false);
|
||||
setIsUpgradeTimedOut(true);
|
||||
}, POLL_TIMEOUT_MS);
|
||||
};
|
||||
|
||||
const handlePurchase = async (storageGb: number) => {
|
||||
setIsSubmitting(true);
|
||||
|
||||
try {
|
||||
const result = await billingApi.createSubscription(databaseId, storageGb);
|
||||
|
||||
setIsCheckoutOpen(true);
|
||||
setIsSubmitting(false);
|
||||
|
||||
Paddle.Checkout.open({
|
||||
transactionId: result.paddleTransactionId,
|
||||
});
|
||||
} catch {
|
||||
message.error('Failed to create subscription');
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleStorageChange = async (storageGb: number) => {
|
||||
if (!subscription) return;
|
||||
|
||||
setIsSubmitting(true);
|
||||
|
||||
try {
|
||||
const result = await billingApi.changeStorage(databaseId, storageGb);
|
||||
|
||||
if (result.applyMode === ChangeStorageApplyMode.Immediate) {
|
||||
setIsSubmitting(false);
|
||||
pollForUpgradeConfirmation(storageGb);
|
||||
} else {
|
||||
setIsSubmitting(false);
|
||||
onSubscriptionChanged();
|
||||
onClose();
|
||||
}
|
||||
} catch {
|
||||
message.error('Failed to change storage');
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
loadSubscription();
|
||||
|
||||
return () => stopPolling();
|
||||
}, [databaseId]);
|
||||
|
||||
useEffect(() => {
|
||||
const handlePaddleEvent = (e: Event) => {
|
||||
const event = (e as CustomEvent<PaddleEvent>).detail;
|
||||
|
||||
if (event.name === 'checkout.completed') {
|
||||
setIsCheckoutOpen(false);
|
||||
pollForPaymentConfirmation();
|
||||
} else if (event.name === 'checkout.closed') {
|
||||
setIsCheckoutOpen(false);
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener('paddle-event', handlePaddleEvent);
|
||||
|
||||
return () => window.removeEventListener('paddle-event', handlePaddleEvent);
|
||||
}, [databaseId]);
|
||||
|
||||
return {
|
||||
subscription,
|
||||
isLoadingSubscription,
|
||||
loadError,
|
||||
isSubmitting,
|
||||
isCheckoutOpen,
|
||||
isWaitingForPayment,
|
||||
isPaymentConfirmed,
|
||||
confirmedStorageGb,
|
||||
isPaymentTimedOut,
|
||||
isWaitingForUpgrade,
|
||||
isUpgradeTimedOut,
|
||||
initialSliderPos,
|
||||
handlePurchase,
|
||||
handleStorageChange,
|
||||
};
|
||||
}
|
||||
2
frontend/src/features/billing/index.ts
Normal file
2
frontend/src/features/billing/index.ts
Normal file
@@ -0,0 +1,2 @@
|
||||
export { BillingComponent } from './ui/BillingComponent';
|
||||
export { PurchaseComponent } from './ui/PurchaseComponent';
|
||||
82
frontend/src/features/billing/models/purchaseUtils.ts
Normal file
82
frontend/src/features/billing/models/purchaseUtils.ts
Normal file
@@ -0,0 +1,82 @@
|
||||
const BACKUPS_COMPRESSION_RATIO = 10;
|
||||
|
||||
function buildBackupSizeSteps(): number[] {
|
||||
const values: number[] = [];
|
||||
for (let i = 1; i <= 100; i++) values.push(i);
|
||||
for (let i = 110; i <= 200; i += 10) values.push(i);
|
||||
return values;
|
||||
}
|
||||
|
||||
function buildStorageSizeSteps(): number[] {
|
||||
const values: number[] = [];
|
||||
for (let i = 20; i <= 100; i++) values.push(i);
|
||||
for (let i = 110; i <= 1000; i += 10) values.push(i);
|
||||
for (let i = 1100; i <= 5000; i += 100) values.push(i);
|
||||
for (let i = 6000; i <= 10000; i += 1000) values.push(i);
|
||||
return values;
|
||||
}
|
||||
|
||||
const BACKUP_SIZE_STEPS = buildBackupSizeSteps();
|
||||
const STORAGE_SIZE_STEPS = buildStorageSizeSteps();
|
||||
|
||||
const DB_SIZE_COMMANDS = [
|
||||
{
|
||||
label: 'PostgreSQL',
|
||||
code: `SELECT pg_size_pretty(pg_database_size(current_database()));`,
|
||||
},
|
||||
{
|
||||
label: 'MySQL / MariaDB',
|
||||
code: `SELECT table_schema AS 'Database',
|
||||
ROUND(SUM(data_length + index_length) / 1024 / 1024, 2) AS 'Size (MB)'
|
||||
FROM information_schema.tables
|
||||
GROUP BY table_schema;`,
|
||||
},
|
||||
{
|
||||
label: 'MongoDB',
|
||||
code: `db.stats(1024 * 1024) // size in MB`,
|
||||
},
|
||||
];
|
||||
|
||||
const POLL_INTERVAL_MS = 3000;
|
||||
const POLL_TIMEOUT_MS = 2 * 60 * 1000;
|
||||
|
||||
function distributeGfs(total: number) {
|
||||
const daily = Math.min(7, total);
|
||||
const weekly = Math.min(4, Math.max(0, total - daily));
|
||||
const monthly = Math.min(12, Math.max(0, total - daily - weekly));
|
||||
const yearly = Math.max(0, total - daily - weekly - monthly);
|
||||
return { daily, weekly, monthly, yearly };
|
||||
}
|
||||
|
||||
function formatSize(gb: number): string {
|
||||
if (gb >= 1000) {
|
||||
const tb = gb / 1000;
|
||||
return tb % 1 === 0 ? `${tb} TB` : `${tb.toFixed(1)} TB`;
|
||||
}
|
||||
return `${gb} GB`;
|
||||
}
|
||||
|
||||
function sliderBackground(pos: number, max: number): React.CSSProperties {
|
||||
const pct = (pos / max) * 100;
|
||||
return {
|
||||
background: `linear-gradient(to right, #155dfc ${pct}%, #1f2937 ${pct}%)`,
|
||||
};
|
||||
}
|
||||
|
||||
function findSliderPosForGb(gb: number): number {
|
||||
const idx = STORAGE_SIZE_STEPS.findIndex((s) => s >= gb);
|
||||
return idx === -1 ? STORAGE_SIZE_STEPS.length - 1 : idx;
|
||||
}
|
||||
|
||||
export {
|
||||
BACKUPS_COMPRESSION_RATIO,
|
||||
BACKUP_SIZE_STEPS,
|
||||
STORAGE_SIZE_STEPS,
|
||||
DB_SIZE_COMMANDS,
|
||||
POLL_INTERVAL_MS,
|
||||
POLL_TIMEOUT_MS,
|
||||
distributeGfs,
|
||||
formatSize,
|
||||
sliderBackground,
|
||||
findSliderPosForGb,
|
||||
};
|
||||
59
frontend/src/features/billing/ui/BackupRetentionSection.tsx
Normal file
59
frontend/src/features/billing/ui/BackupRetentionSection.tsx
Normal file
@@ -0,0 +1,59 @@
|
||||
interface Gfs {
|
||||
daily: number;
|
||||
weekly: number;
|
||||
monthly: number;
|
||||
yearly: number;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
backupsFit: number;
|
||||
gfs: Gfs;
|
||||
}
|
||||
|
||||
export function BackupRetentionSection({ backupsFit, gfs }: Props) {
|
||||
return (
|
||||
<div>
|
||||
<div className="space-y-1.5">
|
||||
<div className="rounded-lg border border-[#ffffff20] bg-[#1f2937]/50 px-3 py-2 text-center">
|
||||
<p className="text-gray-500">Total backups</p>
|
||||
<p className="text-lg font-bold text-gray-200">{backupsFit}</p>
|
||||
</div>
|
||||
|
||||
<div className="my-1 flex items-center gap-3">
|
||||
<div className="h-px flex-1 bg-[#ffffff20]" />
|
||||
<span className="text-sm text-gray-500">or</span>
|
||||
<div className="h-px flex-1 bg-[#ffffff20]" />
|
||||
</div>
|
||||
|
||||
<p className="mb-2 text-sm text-gray-400">
|
||||
Keeps recent backups frequently, older ones less often — broad time at the lowest cost. It
|
||||
is enough to keep the following amount of backups in GFS:
|
||||
</p>
|
||||
|
||||
<div className="grid grid-cols-2 gap-1.5">
|
||||
{(
|
||||
[
|
||||
['Daily', gfs.daily],
|
||||
['Weekly', gfs.weekly],
|
||||
['Monthly', gfs.monthly],
|
||||
['Yearly', gfs.yearly],
|
||||
] as const
|
||||
).map(([label, value]) => (
|
||||
<div
|
||||
key={label}
|
||||
className="rounded-lg border border-[#ffffff20] bg-[#1f2937]/50 px-2 py-1.5 text-center"
|
||||
>
|
||||
<p className="text-xs text-gray-500">{label}</p>
|
||||
<p className="text-base font-bold text-gray-200">{value}</p>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<p className="mt-2 text-sm text-gray-400">
|
||||
You can fine-tune retention values (change daily count, keep only monthly, keep N latest,
|
||||
etc.)
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
432
frontend/src/features/billing/ui/BillingComponent.tsx
Normal file
432
frontend/src/features/billing/ui/BillingComponent.tsx
Normal file
@@ -0,0 +1,432 @@
|
||||
import { App, Button, Spin, Table, Tag } from 'antd';
|
||||
import type { ColumnsType } from 'antd/es/table';
|
||||
import dayjs from 'dayjs';
|
||||
import { useEffect, useState } from 'react';
|
||||
|
||||
import { CLOUD_PRICE_PER_GB } from '../../../constants';
|
||||
import {
|
||||
type Invoice,
|
||||
InvoiceStatus,
|
||||
type Subscription,
|
||||
type SubscriptionEvent,
|
||||
SubscriptionEventType,
|
||||
SubscriptionStatus,
|
||||
billingApi,
|
||||
} from '../../../entity/billing';
|
||||
import type { Database } from '../../../entity/databases';
|
||||
import { getUserShortTimeFormat, getUserTimeFormat } from '../../../shared/time';
|
||||
import { PurchaseComponent } from './PurchaseComponent';
|
||||
|
||||
const MAX_ROWS = 25;
|
||||
|
||||
const STATUS_TAG_COLOR: Record<SubscriptionStatus, string> = {
|
||||
[SubscriptionStatus.Trial]: 'blue',
|
||||
[SubscriptionStatus.Active]: 'green',
|
||||
[SubscriptionStatus.PastDue]: 'orange',
|
||||
[SubscriptionStatus.Canceled]: 'red',
|
||||
[SubscriptionStatus.Expired]: 'default',
|
||||
};
|
||||
|
||||
const STATUS_LABEL: Record<SubscriptionStatus, string> = {
|
||||
[SubscriptionStatus.Trial]: 'Trial',
|
||||
[SubscriptionStatus.Active]: 'Active',
|
||||
[SubscriptionStatus.PastDue]: 'Past Due',
|
||||
[SubscriptionStatus.Canceled]: 'Canceled',
|
||||
[SubscriptionStatus.Expired]: 'Expired',
|
||||
};
|
||||
|
||||
const INVOICE_STATUS_COLOR: Record<InvoiceStatus, string> = {
|
||||
[InvoiceStatus.Paid]: 'green',
|
||||
[InvoiceStatus.Pending]: 'blue',
|
||||
[InvoiceStatus.Failed]: 'red',
|
||||
[InvoiceStatus.Refunded]: 'orange',
|
||||
[InvoiceStatus.Disputed]: 'red',
|
||||
};
|
||||
|
||||
const INVOICE_STATUS_LABEL: Record<InvoiceStatus, string> = {
|
||||
[InvoiceStatus.Paid]: 'Paid',
|
||||
[InvoiceStatus.Pending]: 'Pending',
|
||||
[InvoiceStatus.Failed]: 'Failed',
|
||||
[InvoiceStatus.Refunded]: 'Refunded',
|
||||
[InvoiceStatus.Disputed]: 'Disputed',
|
||||
};
|
||||
|
||||
const EVENT_TYPE_LABEL: Record<SubscriptionEventType, string> = {
|
||||
[SubscriptionEventType.Created]: 'Subscription created',
|
||||
[SubscriptionEventType.Upgraded]: 'Storage upgraded',
|
||||
[SubscriptionEventType.Downgraded]: 'Storage downgraded',
|
||||
[SubscriptionEventType.NewBillingCycleStarted]: 'New billing cycle started',
|
||||
[SubscriptionEventType.Canceled]: 'Canceled',
|
||||
[SubscriptionEventType.Reactivated]: 'Reactivated',
|
||||
[SubscriptionEventType.Expired]: 'Expired',
|
||||
[SubscriptionEventType.PastDue]: 'Past Due',
|
||||
[SubscriptionEventType.RecoveredFromPastDue]: 'Recovered',
|
||||
[SubscriptionEventType.Refund]: 'Refund',
|
||||
[SubscriptionEventType.Dispute]: 'Dispute',
|
||||
};
|
||||
|
||||
interface Props {
|
||||
database: Database;
|
||||
isCanManageDBs: boolean;
|
||||
}
|
||||
|
||||
export const BillingComponent = ({ database, isCanManageDBs }: Props) => {
|
||||
const { message } = App.useApp();
|
||||
|
||||
const [subscription, setSubscription] = useState<Subscription | null>(null);
|
||||
const [isLoadingSubscription, setIsLoadingSubscription] = useState(true);
|
||||
const [invoices, setInvoices] = useState<Invoice[]>([]);
|
||||
const [isLoadingInvoices, setIsLoadingInvoices] = useState(false);
|
||||
const [totalInvoices, setTotalInvoices] = useState(0);
|
||||
const [events, setEvents] = useState<SubscriptionEvent[]>([]);
|
||||
const [isLoadingEvents, setIsLoadingEvents] = useState(false);
|
||||
const [totalEvents, setTotalEvents] = useState(0);
|
||||
const [isPurchaseModalOpen, setIsPurchaseModalOpen] = useState(false);
|
||||
const [isPortalLoading, setIsPortalLoading] = useState(false);
|
||||
|
||||
const loadSubscription = async (): Promise<Subscription | null> => {
|
||||
setIsLoadingSubscription(true);
|
||||
|
||||
try {
|
||||
const sub = await billingApi.getSubscription(database.id);
|
||||
setSubscription(sub);
|
||||
|
||||
return sub;
|
||||
} catch {
|
||||
setSubscription(null);
|
||||
|
||||
return null;
|
||||
} finally {
|
||||
setIsLoadingSubscription(false);
|
||||
}
|
||||
};
|
||||
|
||||
const loadInvoices = async (subscriptionId: string) => {
|
||||
setIsLoadingInvoices(true);
|
||||
|
||||
try {
|
||||
const response = await billingApi.getInvoices(subscriptionId, MAX_ROWS, 0);
|
||||
setInvoices(response.invoices);
|
||||
setTotalInvoices(response.total);
|
||||
} catch {
|
||||
setInvoices([]);
|
||||
} finally {
|
||||
setIsLoadingInvoices(false);
|
||||
}
|
||||
};
|
||||
|
||||
const loadEvents = async (subscriptionId: string) => {
|
||||
setIsLoadingEvents(true);
|
||||
|
||||
try {
|
||||
const response = await billingApi.getSubscriptionEvents(subscriptionId, MAX_ROWS, 0);
|
||||
setEvents(response.events);
|
||||
setTotalEvents(response.total);
|
||||
} catch {
|
||||
setEvents([]);
|
||||
} finally {
|
||||
setIsLoadingEvents(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handlePortalClick = async () => {
|
||||
if (!subscription) return;
|
||||
|
||||
setIsPortalLoading(true);
|
||||
|
||||
try {
|
||||
const result = await billingApi.getPortalSession(subscription.id);
|
||||
window.open(result.url, '_blank');
|
||||
} catch {
|
||||
message.error('Failed to open billing portal');
|
||||
} finally {
|
||||
setIsPortalLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSubscriptionChanged = async () => {
|
||||
const sub = await loadSubscription();
|
||||
|
||||
if (sub) {
|
||||
loadInvoices(sub.id);
|
||||
loadEvents(sub.id);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
loadSubscription();
|
||||
}, [database.id]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!subscription) return;
|
||||
|
||||
loadInvoices(subscription.id);
|
||||
loadEvents(subscription.id);
|
||||
}, [subscription?.id]);
|
||||
|
||||
const timeFormat = getUserTimeFormat();
|
||||
const shortFormat = getUserShortTimeFormat();
|
||||
|
||||
const canPurchase =
|
||||
subscription &&
|
||||
(subscription.status === SubscriptionStatus.Trial ||
|
||||
subscription.status === SubscriptionStatus.Expired);
|
||||
|
||||
const canAccessPortal =
|
||||
subscription &&
|
||||
(subscription.status === SubscriptionStatus.Active ||
|
||||
subscription.status === SubscriptionStatus.PastDue ||
|
||||
subscription.status === SubscriptionStatus.Canceled);
|
||||
|
||||
const isTrial = subscription?.status === SubscriptionStatus.Trial;
|
||||
const monthlyPrice = subscription ? subscription.storageGb * CLOUD_PRICE_PER_GB : 0;
|
||||
|
||||
const invoiceColumns: ColumnsType<Invoice> = [
|
||||
{
|
||||
title: 'Period',
|
||||
dataIndex: 'periodStart',
|
||||
render: (_: unknown, record: Invoice) =>
|
||||
`${dayjs.utc(record.periodStart).local().format(shortFormat.format)} - ${dayjs.utc(record.periodEnd).local().format(shortFormat.format)}`,
|
||||
},
|
||||
{
|
||||
title: 'Amount',
|
||||
dataIndex: 'amountCents',
|
||||
render: (cents: number) => `$${(cents / 100).toFixed(2)}`,
|
||||
},
|
||||
{
|
||||
title: 'Storage',
|
||||
dataIndex: 'storageGb',
|
||||
render: (gb: number) => `${gb} GB`,
|
||||
},
|
||||
{
|
||||
title: 'Status',
|
||||
dataIndex: 'status',
|
||||
render: (status: InvoiceStatus) => (
|
||||
<Tag color={INVOICE_STATUS_COLOR[status]}>{INVOICE_STATUS_LABEL[status]}</Tag>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: 'Paid At',
|
||||
dataIndex: 'paidAt',
|
||||
render: (paidAt: string | undefined) =>
|
||||
paidAt ? dayjs.utc(paidAt).local().format(timeFormat.format) : '-',
|
||||
},
|
||||
];
|
||||
|
||||
const eventColumns: ColumnsType<SubscriptionEvent> = [
|
||||
{
|
||||
title: 'Date',
|
||||
dataIndex: 'createdAt',
|
||||
render: (createdAt: string) => (
|
||||
<div>
|
||||
<div>{dayjs.utc(createdAt).local().format(timeFormat.format)}</div>
|
||||
<div className="text-xs text-gray-500">{dayjs.utc(createdAt).local().fromNow()}</div>
|
||||
</div>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: 'Event',
|
||||
dataIndex: 'type',
|
||||
render: (type: SubscriptionEventType) => EVENT_TYPE_LABEL[type] ?? type,
|
||||
},
|
||||
{
|
||||
title: 'Details',
|
||||
render: (_: unknown, record: SubscriptionEvent) => {
|
||||
const parts: string[] = [];
|
||||
|
||||
if (record.oldStorageGb != null && record.newStorageGb != null) {
|
||||
parts.push(`${record.oldStorageGb} GB \u2192 ${record.newStorageGb} GB`);
|
||||
}
|
||||
|
||||
if (record.oldStatus != null && record.newStatus != null) {
|
||||
parts.push(
|
||||
`${STATUS_LABEL[record.oldStatus] ?? record.oldStatus} \u2192 ${STATUS_LABEL[record.newStatus] ?? record.newStatus}`,
|
||||
);
|
||||
}
|
||||
|
||||
return parts.length > 0 ? parts.join(', ') : '-';
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
if (isLoadingSubscription) {
|
||||
return (
|
||||
<div className="flex w-full justify-center rounded-tr-md rounded-br-md rounded-bl-md bg-white p-10 shadow dark:bg-gray-800">
|
||||
<Spin size="large" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="w-full rounded-tr-md rounded-br-md rounded-bl-md bg-white p-3 shadow md:p-5 dark:bg-gray-800">
|
||||
<div className="max-w-[720px]">
|
||||
<h2 className="text-lg font-bold md:text-xl dark:text-white">Billing</h2>
|
||||
|
||||
{/* Subscription Summary */}
|
||||
{!subscription && (
|
||||
<div className="mt-4">
|
||||
<p className="text-gray-500 dark:text-gray-400">
|
||||
No subscription found for this database.
|
||||
</p>
|
||||
|
||||
{isCanManageDBs && (
|
||||
<Button type="primary" className="mt-3" onClick={() => setIsPurchaseModalOpen(true)}>
|
||||
Purchase
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{subscription && (
|
||||
<>
|
||||
<div className="mt-4 rounded-lg border border-gray-200 p-4 dark:border-gray-700">
|
||||
<div className="flex items-center gap-2">
|
||||
<Tag
|
||||
color={STATUS_TAG_COLOR[subscription.status]}
|
||||
style={{ fontSize: 14, padding: '2px 12px' }}
|
||||
>
|
||||
{STATUS_LABEL[subscription.status]}
|
||||
</Tag>
|
||||
|
||||
{!isTrial && (
|
||||
<span className="text-2xl font-bold dark:text-white">
|
||||
${monthlyPrice.toFixed(2)}
|
||||
<span className="text-sm font-normal text-gray-500">/mo</span>
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="mt-4 grid grid-cols-2 gap-3">
|
||||
<div className="rounded-md bg-gray-50 px-3 py-2 dark:bg-gray-700/50">
|
||||
<div className="text-xs text-gray-500 dark:text-gray-400">Storage</div>
|
||||
<div className="font-medium dark:text-gray-200">
|
||||
{subscription.storageGb} GB
|
||||
{subscription.pendingStorageGb != null && (
|
||||
<span className="ml-1 text-xs text-yellow-500">
|
||||
({subscription.pendingStorageGb} GB pending)
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="rounded-md bg-gray-50 px-3 py-2 dark:bg-gray-700/50">
|
||||
<div className="text-xs text-gray-500 dark:text-gray-400">Current period</div>
|
||||
<div className="font-medium dark:text-gray-200">
|
||||
{dayjs.utc(subscription.currentPeriodStart).local().format(shortFormat.format)}{' '}
|
||||
- {dayjs.utc(subscription.currentPeriodEnd).local().format(shortFormat.format)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{subscription.canceledAt && (
|
||||
<div className="rounded-md bg-red-50 px-3 py-2 dark:bg-red-900/20">
|
||||
<div className="text-xs text-gray-500 dark:text-gray-400">Canceled at</div>
|
||||
<div className="font-medium text-red-500">
|
||||
{dayjs.utc(subscription.canceledAt).local().format(timeFormat.format)}
|
||||
</div>
|
||||
<div className="text-xs text-gray-500">
|
||||
{dayjs.utc(subscription.canceledAt).local().fromNow()}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{subscription.dataRetentionGracePeriodUntil && (
|
||||
<div className="rounded-md bg-yellow-50 px-3 py-2 dark:bg-yellow-900/20">
|
||||
<div className="text-xs text-gray-500 dark:text-gray-400">
|
||||
Data retained until
|
||||
</div>
|
||||
<div className="font-medium text-yellow-500">
|
||||
{dayjs
|
||||
.utc(subscription.dataRetentionGracePeriodUntil)
|
||||
.local()
|
||||
.format(timeFormat.format)}
|
||||
</div>
|
||||
<div className="text-xs text-gray-500">
|
||||
{dayjs.utc(subscription.dataRetentionGracePeriodUntil).local().fromNow()}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{isCanManageDBs && (
|
||||
<div className="mt-4 flex flex-wrap gap-2 border-t border-gray-200 pt-4 dark:border-gray-700">
|
||||
{canPurchase && (
|
||||
<Button type="primary" onClick={() => setIsPurchaseModalOpen(true)}>
|
||||
Purchase
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{canAccessPortal && (
|
||||
<>
|
||||
{subscription.status !== SubscriptionStatus.Canceled && (
|
||||
<Button onClick={() => setIsPurchaseModalOpen(true)}>Change storage</Button>
|
||||
)}
|
||||
{subscription.status === SubscriptionStatus.Canceled ? (
|
||||
<Button
|
||||
type="primary"
|
||||
loading={isPortalLoading}
|
||||
onClick={handlePortalClick}
|
||||
>
|
||||
Resume subscription
|
||||
</Button>
|
||||
) : (
|
||||
<Button loading={isPortalLoading} onClick={handlePortalClick}>
|
||||
Manage subscription
|
||||
</Button>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Invoices */}
|
||||
<h3 className="mt-6 mb-3 text-base font-bold dark:text-white">Invoices</h3>
|
||||
|
||||
<Table
|
||||
bordered
|
||||
size="small"
|
||||
columns={invoiceColumns}
|
||||
dataSource={invoices}
|
||||
rowKey="id"
|
||||
loading={isLoadingInvoices}
|
||||
pagination={false}
|
||||
/>
|
||||
|
||||
{totalInvoices > MAX_ROWS && (
|
||||
<p className="mt-1 text-xs text-gray-500">
|
||||
Showing {MAX_ROWS} of {totalInvoices} invoices
|
||||
</p>
|
||||
)}
|
||||
|
||||
{/* Activity */}
|
||||
<h3 className="mt-6 mb-3 text-base font-bold dark:text-white">Activity</h3>
|
||||
|
||||
<Table
|
||||
bordered
|
||||
size="small"
|
||||
columns={eventColumns}
|
||||
dataSource={events}
|
||||
rowKey="id"
|
||||
loading={isLoadingEvents}
|
||||
pagination={false}
|
||||
/>
|
||||
|
||||
{totalEvents > MAX_ROWS && (
|
||||
<p className="mt-1 text-xs text-gray-500">
|
||||
Showing {MAX_ROWS} of {totalEvents} events
|
||||
</p>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
{isPurchaseModalOpen && (
|
||||
<PurchaseComponent
|
||||
databaseId={database.id}
|
||||
onSubscriptionChanged={handleSubscriptionChanged}
|
||||
onClose={() => setIsPurchaseModalOpen(false)}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
64
frontend/src/features/billing/ui/DbSizeCommands.tsx
Normal file
64
frontend/src/features/billing/ui/DbSizeCommands.tsx
Normal file
@@ -0,0 +1,64 @@
|
||||
import { useState } from 'react';
|
||||
|
||||
interface DbSizeCommand {
|
||||
label: string;
|
||||
code: string;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
commands: DbSizeCommand[];
|
||||
}
|
||||
|
||||
export function DbSizeCommands({ commands }: Props) {
|
||||
const [copiedIndex, setCopiedIndex] = useState<number | null>(null);
|
||||
|
||||
return (
|
||||
<details className="group mb-2">
|
||||
<summary className="flex cursor-pointer list-none items-center gap-1.5 text-gray-500 transition-colors hover:text-gray-400">
|
||||
<svg
|
||||
width="12"
|
||||
height="12"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
className="transition-transform group-open:rotate-90"
|
||||
>
|
||||
<path d="M9 18l6-6-6-6" />
|
||||
</svg>
|
||||
How to check DB size?
|
||||
</summary>
|
||||
|
||||
<div className="mt-2 space-y-1.5">
|
||||
{commands.map((cmd, index) => (
|
||||
<div key={index}>
|
||||
<p className="mb-1 text-xs text-gray-400">{cmd.label}</p>
|
||||
<div className="relative">
|
||||
<pre className="overflow-x-auto rounded-lg border border-[#ffffff20] bg-[#1f2937] px-2.5 py-1.5 pr-16 text-xs">
|
||||
<code className="block whitespace-pre text-gray-300">{cmd.code}</code>
|
||||
</pre>
|
||||
<button
|
||||
onClick={async () => {
|
||||
try {
|
||||
await navigator.clipboard.writeText(cmd.code);
|
||||
setCopiedIndex(index);
|
||||
setTimeout(() => setCopiedIndex(null), 2000);
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
}}
|
||||
className={`absolute top-2 right-2 rounded border border-[#ffffff20] px-2 py-0.5 text-white transition-colors ${
|
||||
copiedIndex === index ? 'bg-green-500' : 'bg-blue-600 hover:bg-blue-700'
|
||||
}`}
|
||||
>
|
||||
{copiedIndex === index ? 'Copied!' : 'Copy'}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</details>
|
||||
);
|
||||
}
|
||||
74
frontend/src/features/billing/ui/PriceActionBar.tsx
Normal file
74
frontend/src/features/billing/ui/PriceActionBar.tsx
Normal file
@@ -0,0 +1,74 @@
|
||||
import { Button } from 'antd';
|
||||
|
||||
import { SubscriptionStatus } from '../../../entity/billing';
|
||||
|
||||
interface Props {
|
||||
monthlyPrice: number;
|
||||
currentPrice: number;
|
||||
isPurchaseFlow: boolean;
|
||||
isChangeFlow: boolean;
|
||||
isUpgrade: boolean;
|
||||
isDowngrade: boolean;
|
||||
isSameStorage: boolean;
|
||||
isSubmitting: boolean;
|
||||
subscriptionStatus: SubscriptionStatus;
|
||||
onPurchase: () => void;
|
||||
onChangeStorage: () => void;
|
||||
}
|
||||
|
||||
export function PriceActionBar({
|
||||
monthlyPrice,
|
||||
currentPrice,
|
||||
isPurchaseFlow,
|
||||
isChangeFlow,
|
||||
isUpgrade,
|
||||
isDowngrade,
|
||||
isSameStorage,
|
||||
isSubmitting,
|
||||
subscriptionStatus,
|
||||
onPurchase,
|
||||
onChangeStorage,
|
||||
}: Props) {
|
||||
return (
|
||||
<div className="mt-4 flex items-center gap-4 border-t border-[#ffffff20] pt-4">
|
||||
<div className="flex-1">
|
||||
<p className="text-2xl font-bold">
|
||||
${monthlyPrice.toFixed(2)}
|
||||
<span className="text-base font-medium text-gray-400">/mo</span>
|
||||
</p>
|
||||
|
||||
{isChangeFlow && !isSameStorage && (
|
||||
<p className="text-xs text-gray-400">Currently ${currentPrice.toFixed(2)}/mo</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col items-end gap-1">
|
||||
{isPurchaseFlow && (
|
||||
<Button type="primary" size="large" loading={isSubmitting} onClick={onPurchase}>
|
||||
{subscriptionStatus === SubscriptionStatus.Canceled ? 'Re-subscribe' : 'Purchase'}
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{isChangeFlow && (
|
||||
<>
|
||||
<Button
|
||||
type="primary"
|
||||
size="large"
|
||||
loading={isSubmitting}
|
||||
disabled={!!isSameStorage}
|
||||
onClick={onChangeStorage}
|
||||
>
|
||||
{isUpgrade ? 'Upgrade' : isDowngrade ? 'Downgrade' : 'Change Storage'}
|
||||
</Button>
|
||||
|
||||
{isDowngrade && (
|
||||
<p className="text-xs text-gray-500">
|
||||
Storage will be reduced from next billing cycle
|
||||
</p>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
.calcSlider {
|
||||
-webkit-appearance: none;
|
||||
appearance: none;
|
||||
width: 100%;
|
||||
height: 6px;
|
||||
border-radius: 3px;
|
||||
outline: none;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.calcSlider::-webkit-slider-thumb {
|
||||
-webkit-appearance: none;
|
||||
appearance: none;
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
border-radius: 50%;
|
||||
background: #155dfc;
|
||||
border: 2px solid #fff;
|
||||
cursor: pointer;
|
||||
transition: transform 0.15s;
|
||||
box-shadow: 0 0 8px rgba(21, 93, 252, 0.5);
|
||||
}
|
||||
|
||||
.calcSlider::-webkit-slider-thumb:hover {
|
||||
transform: scale(1.15);
|
||||
}
|
||||
|
||||
.calcSlider::-moz-range-thumb {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
border-radius: 50%;
|
||||
background: #155dfc;
|
||||
border: 2px solid #fff;
|
||||
cursor: pointer;
|
||||
box-shadow: 0 0 8px rgba(21, 93, 252, 0.5);
|
||||
}
|
||||
|
||||
.calcSlider::-moz-range-track {
|
||||
height: 6px;
|
||||
border-radius: 3px;
|
||||
background: #1f2937;
|
||||
}
|
||||
|
||||
.calcSlider::-moz-range-progress {
|
||||
background: #155dfc;
|
||||
border-radius: 3px;
|
||||
height: 6px;
|
||||
}
|
||||
|
||||
.calcSlider:focus-visible::-webkit-slider-thumb {
|
||||
box-shadow:
|
||||
0 0 0 3px rgba(21, 93, 252, 0.3),
|
||||
0 0 8px rgba(21, 93, 252, 0.5);
|
||||
}
|
||||
189
frontend/src/features/billing/ui/PurchaseComponent.tsx
Normal file
189
frontend/src/features/billing/ui/PurchaseComponent.tsx
Normal file
@@ -0,0 +1,189 @@
|
||||
import { Button, Modal, Spin } from 'antd';
|
||||
import { useEffect, useState } from 'react';
|
||||
|
||||
import { usePurchaseFlow } from '../hooks/usePurchaseFlow';
|
||||
|
||||
import { CLOUD_PRICE_PER_GB } from '../../../constants';
|
||||
import { SubscriptionStatus } from '../../../entity/billing';
|
||||
import {
|
||||
BACKUPS_COMPRESSION_RATIO,
|
||||
BACKUP_SIZE_STEPS,
|
||||
STORAGE_SIZE_STEPS,
|
||||
distributeGfs,
|
||||
formatSize,
|
||||
} from '../models/purchaseUtils';
|
||||
import { BackupRetentionSection } from './BackupRetentionSection';
|
||||
import { PriceActionBar } from './PriceActionBar';
|
||||
import { StorageSlidersSection } from './StorageSlidersSection';
|
||||
|
||||
interface Props {
|
||||
databaseId: string;
|
||||
onSubscriptionChanged: () => void;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export function PurchaseComponent({ databaseId, onSubscriptionChanged, onClose }: Props) {
|
||||
const flow = usePurchaseFlow({ databaseId, onSubscriptionChanged, onClose });
|
||||
|
||||
const [storageSliderPos, setStorageSliderPos] = useState(0);
|
||||
const [backupSliderPos, setBackupSliderPos] = useState(0);
|
||||
|
||||
useEffect(() => {
|
||||
if (flow.initialSliderPos > 0) {
|
||||
setStorageSliderPos(flow.initialSliderPos);
|
||||
}
|
||||
}, [flow.initialSliderPos]);
|
||||
|
||||
const singleBackupSizeGb = BACKUP_SIZE_STEPS[backupSliderPos];
|
||||
const minStoragePosIndex = STORAGE_SIZE_STEPS.findIndex((s) => s >= singleBackupSizeGb);
|
||||
const minStoragePos =
|
||||
minStoragePosIndex === -1 ? STORAGE_SIZE_STEPS.length - 1 : minStoragePosIndex;
|
||||
const effectiveStoragePos = Math.max(storageSliderPos, minStoragePos);
|
||||
const newStorageGb = STORAGE_SIZE_STEPS[effectiveStoragePos];
|
||||
const approximateDbSize = singleBackupSizeGb * BACKUPS_COMPRESSION_RATIO;
|
||||
const backupsFit = Math.floor(newStorageGb / singleBackupSizeGb);
|
||||
const gfs = distributeGfs(backupsFit);
|
||||
const monthlyPrice = newStorageGb * CLOUD_PRICE_PER_GB;
|
||||
|
||||
const { subscription } = flow;
|
||||
|
||||
const isPurchaseFlow =
|
||||
subscription &&
|
||||
(subscription.status === SubscriptionStatus.Trial ||
|
||||
subscription.status === SubscriptionStatus.Canceled ||
|
||||
subscription.status === SubscriptionStatus.Expired);
|
||||
|
||||
const isChangeFlow =
|
||||
subscription &&
|
||||
(subscription.status === SubscriptionStatus.Active ||
|
||||
subscription.status === SubscriptionStatus.PastDue);
|
||||
|
||||
const isUpgrade = isChangeFlow && newStorageGb > subscription.storageGb;
|
||||
const isDowngrade = isChangeFlow && newStorageGb < subscription.storageGb;
|
||||
const isSameStorage = isChangeFlow && newStorageGb === subscription.storageGb;
|
||||
const currentPrice = subscription ? subscription.storageGb * CLOUD_PRICE_PER_GB : 0;
|
||||
|
||||
const modalTitle = isPurchaseFlow
|
||||
? subscription.status === SubscriptionStatus.Canceled
|
||||
? 'Re-subscribe'
|
||||
: 'Purchase subscription'
|
||||
: 'Change Storage';
|
||||
|
||||
const isShowingForm =
|
||||
subscription &&
|
||||
!flow.isLoadingSubscription &&
|
||||
!flow.isWaitingForUpgrade &&
|
||||
!flow.isUpgradeTimedOut &&
|
||||
!flow.isWaitingForPayment &&
|
||||
!flow.isPaymentConfirmed &&
|
||||
!flow.isPaymentTimedOut &&
|
||||
!flow.isCheckoutOpen;
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title={modalTitle}
|
||||
open
|
||||
onCancel={onClose}
|
||||
footer={null}
|
||||
width={700}
|
||||
maskClosable={false}
|
||||
>
|
||||
{flow.isLoadingSubscription && (
|
||||
<div className="flex justify-center py-10">
|
||||
<Spin size="large" />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{flow.loadError && <div className="py-10 text-center text-red-500">{flow.loadError}</div>}
|
||||
|
||||
{flow.isWaitingForPayment && (
|
||||
<div className="flex flex-col items-center gap-4 py-10">
|
||||
<Spin size="large" />
|
||||
<p className="text-gray-400">Confirming your payment...</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{flow.isPaymentConfirmed && (
|
||||
<div className="py-6 text-center">
|
||||
<p className="mb-1 text-lg font-semibold text-green-600 dark:text-green-400">
|
||||
Payment successful!
|
||||
</p>
|
||||
{flow.confirmedStorageGb !== undefined && (
|
||||
<p className="mb-4 text-gray-500 dark:text-gray-400">
|
||||
Your subscription is now active with {flow.confirmedStorageGb} GB of storage.
|
||||
</p>
|
||||
)}
|
||||
<Button type="primary" onClick={onClose}>
|
||||
OK
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{flow.isPaymentTimedOut && (
|
||||
<div className="py-6 text-center">
|
||||
<p className="mb-4 text-yellow-500">
|
||||
Payment confirmation is taking longer than expected. Please reload the page to check the
|
||||
status.
|
||||
</p>
|
||||
<Button onClick={() => window.location.reload()}>Reload page</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{flow.isUpgradeTimedOut && (
|
||||
<div className="py-6 text-center">
|
||||
<p className="mb-2 text-yellow-500">
|
||||
Upgrade is taking longer than expected, it will be applied shortly. Please reload the
|
||||
page
|
||||
</p>
|
||||
<Button onClick={onClose}>Close</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{flow.isWaitingForUpgrade && !flow.isUpgradeTimedOut && (
|
||||
<div className="flex flex-col items-center gap-4 py-10">
|
||||
<Spin size="large" />
|
||||
<p className="text-gray-400">Waiting for storage upgrade confirmation...</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isShowingForm && (
|
||||
<div>
|
||||
{isChangeFlow && subscription.pendingStorageGb !== undefined && (
|
||||
<div className="mb-4 rounded-lg border border-yellow-600/30 bg-yellow-900/20 px-4 py-3 text-sm text-yellow-400">
|
||||
Pending storage change to {formatSize(subscription.pendingStorageGb)} from next
|
||||
billing cycle
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="md:grid md:grid-cols-2 md:gap-6">
|
||||
<StorageSlidersSection
|
||||
onStorageSliderChange={setStorageSliderPos}
|
||||
backupSliderPos={backupSliderPos}
|
||||
onBackupSliderChange={setBackupSliderPos}
|
||||
effectiveStoragePos={effectiveStoragePos}
|
||||
newStorageGb={newStorageGb}
|
||||
singleBackupSizeGb={singleBackupSizeGb}
|
||||
approximateDbSize={approximateDbSize}
|
||||
/>
|
||||
|
||||
<BackupRetentionSection backupsFit={backupsFit} gfs={gfs} />
|
||||
</div>
|
||||
|
||||
<PriceActionBar
|
||||
monthlyPrice={monthlyPrice}
|
||||
currentPrice={currentPrice}
|
||||
isPurchaseFlow={!!isPurchaseFlow}
|
||||
isChangeFlow={!!isChangeFlow}
|
||||
isUpgrade={!!isUpgrade}
|
||||
isDowngrade={!!isDowngrade}
|
||||
isSameStorage={!!isSameStorage}
|
||||
isSubmitting={flow.isSubmitting}
|
||||
subscriptionStatus={subscription.status}
|
||||
onPurchase={() => flow.handlePurchase(newStorageGb)}
|
||||
onChangeStorage={() => flow.handleStorageChange(newStorageGb)}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user