From 4bee78646afbd81101b8c5afc786b7487ec2a154 Mon Sep 17 00:00:00 2001 From: Rostislav Dugin Date: Sat, 28 Mar 2026 22:02:22 +0300 Subject: [PATCH 1/3] REFACTOR (go): Refactor go to follow modern syntax guidelines --- AGENTS.md | 220 +++++++++++++----- agent/internal/config/config.go | 3 +- .../features/full_backup/backuper_test.go | 22 +- .../features/restore/restorer_test.go | 29 ++- .../features/start/lock_watcher_test.go | 8 +- agent/internal/features/wal/streamer.go | 4 +- agent/internal/features/wal/streamer_test.go | 18 +- agent/internal/logger/logger.go | 12 +- agent/internal/logger/logger_test.go | 2 +- backend/cmd/main.go | 2 +- backend/internal/config/config.go | 9 +- .../features/audit_logs/background_service.go | 42 ++-- .../audit_logs/background_service_test.go | 2 +- backend/internal/features/audit_logs/di.go | 28 +-- .../backups/backups/backuping/backuper.go | 116 +++++---- .../backups/backups/backuping/cleaner.go | 60 +++-- .../backups/backuping/cleaner_gfs_test.go | 28 +-- .../backups/backups/backuping/cleaner_test.go | 16 +- .../features/backups/backups/backuping/di.go | 5 - .../backups/backups/backuping/registry.go | 39 ++-- .../backups/backuping/registry_test.go | 36 ++- .../backups/backups/backuping/scheduler.go | 108 ++++----- .../backups/backups/backuping/testing.go | 5 - .../backups/backups/download/background.go | 52 ++--- .../features/backups/backups/download/di.go | 5 - .../features/backups/backups/services/di.go | 30 +-- .../internal/features/backups/config/di.go | 23 +- .../internal/features/backups/config/dto.go | 2 +- .../internal/features/backups/config/model.go | 4 +- .../features/billing/controller_test.go | 2 +- backend/internal/features/billing/di.go | 20 +- .../features/billing/models/invoice.go | 2 +- .../features/billing/models/subscription.go | 8 +- .../internal/features/billing/paddle/di.go | 66 +++--- backend/internal/features/billing/service.go | 72 +++--- .../databases/databases/mariadb/model_test.go | 13 +- .../databases/databases/mongodb/model_test.go | 34 +-- .../databases/databases/mysql/model_test.go | 13 +- .../databases/postgresql/model_test.go | 25 +- backend/internal/features/databases/di.go | 24 +- backend/internal/features/databases/model.go | 10 +- .../healthcheck/attempt/background_service.go | 41 ++-- .../features/healthcheck/attempt/di.go | 5 - .../features/healthcheck/config/di.go | 26 +-- backend/internal/features/notifiers/di.go | 22 +- backend/internal/features/notifiers/model.go | 12 +- .../features/notifiers/models/teams/model.go | 2 +- .../features/restores/controller_test.go | 2 +- backend/internal/features/restores/di.go | 24 +- .../features/restores/restoring/di.go | 5 - .../features/restores/restoring/dto.go | 8 +- .../features/restores/restoring/registry.go | 39 ++-- .../restores/restoring/registry_test.go | 36 ++- .../features/restores/restoring/restorer.go | 134 +++++------ .../restores/restoring/restorer_test.go | 6 +- .../features/restores/restoring/scheduler.go | 100 ++++---- .../restores/restoring/scheduler_test.go | 2 +- .../features/restores/restoring/testing.go | 4 - backend/internal/features/storages/di.go | 23 +- .../internal/features/storages/model_test.go | 6 +- .../tasks/cancellation/cancel_manager_test.go | 18 +- .../features/tasks/cancellation/di.go | 22 +- .../tests/mongodb_backup_restore_test.go | 18 +- .../users/controllers/password_reset_test.go | 4 +- backend/internal/storage/storage.go | 9 +- backend/internal/util/cache/cache.go | 52 ++--- backend/internal/util/cache/cache_test.go | 2 +- backend/internal/util/logger/logger.go | 146 ++++++------ .../util/logger/victorialogs_writer.go | 45 ++-- 69 files changed, 917 insertions(+), 1115 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 9d068e0..42599f6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -20,6 +20,7 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ - [Time handling](#time-handling) - [Logging](#logging) - [CRUD examples](#crud-examples) + - [Modern Go](#modern-go) - [Frontend guidelines](#frontend-guidelines) - [React component structure](#react-component-structure) @@ -598,7 +599,7 @@ func GetOrderRepository() *repositories.OrderRepository { #### SetupDependencies() pattern -**All `SetupDependencies()` functions must use sync.Once to ensure idempotent execution.** +**All `SetupDependencies()` functions must use `sync.OnceFunc` to ensure idempotent execution.** This pattern allows `SetupDependencies()` to be safely called multiple times (especially in tests) while ensuring the actual setup logic executes only once. @@ -609,45 +610,28 @@ package feature import ( "sync" - "sync/atomic" - "databasus-backend/internal/util/logger" ) -var ( - setupOnce sync.Once - isSetup atomic.Bool -) - -func SetupDependencies() { - wasAlreadySetup := isSetup.Load() - - setupOnce.Do(func() { - // Initialize dependencies here - someService.SetDependency(otherService) - anotherService.AddListener(listener) - - isSetup.Store(true) - }) - - if wasAlreadySetup { - logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") - } -} +var SetupDependencies = sync.OnceFunc(func() { + // Initialize dependencies here + someService.SetDependency(otherService) + anotherService.AddListener(listener) +}) ``` **Why this pattern:** - **Tests can call multiple times**: Test setup often calls `SetupDependencies()` multiple times without issues - **Thread-safe**: Works correctly with concurrent calls (nanoseconds or seconds apart) -- **Idempotent**: Subsequent calls are safe, only log warning +- **Idempotent**: Subsequent calls are no-ops - **No panics**: Does not break tests or production code on multiple calls +- **Concise**: `sync.OnceFunc` (Go 1.21+) replaces the manual `sync.Once` + `atomic.Bool` + `Do()` boilerplate **Key Points:** -1. Check `isSetup.Load()` **before** calling `Do()` to detect previous executions -2. Set `isSetup.Store(true)` **inside** the `Do()` closure after setup completes -3. Log warning if already setup (helps identify unnecessary duplicate calls) -4. All setup logic must be inside the `Do()` closure +1. Use `sync.OnceFunc` instead of manual `sync.Once` + `atomic.Bool` pattern +2. All setup logic must be inside the `OnceFunc` closure +3. The returned function is safe to call concurrently and multiple times --- @@ -671,33 +655,26 @@ import ( type BackgroundService struct { // ... existing fields ... - runOnce sync.Once - hasRun atomic.Bool + hasRun atomic.Bool } func (s *BackgroundService) Run(ctx context.Context) { - wasAlreadyRun := s.hasRun.Load() - - s.runOnce.Do(func() { - s.hasRun.Store(true) - - // Existing infinite loop logic - ticker := time.NewTicker(1 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - s.doWork() - } - } - }) - - if wasAlreadyRun { + if s.hasRun.Swap(true) { panic(fmt.Sprintf("%T.Run() called multiple times", s)) } + + // Existing infinite loop logic + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + s.doWork() + } + } } ``` @@ -718,11 +695,9 @@ func (s *BackgroundService) Run(ctx context.Context) { **Key Points:** -1. Check `hasRun.Load()` **before** calling `Do()` to detect previous executions -2. Set `hasRun.Store(true)` **inside** the `Do()` closure before starting work -3. **Always panic** if already run (never just log warning) -4. All run logic must be inside the `Do()` closure -5. This pattern is **thread-safe** for any timing (concurrent or sequential calls) +1. Use `atomic.Bool.Swap(true)` to atomically check-and-set in one call — no need for `sync.Once` +2. **Always panic** if already run (never just log warning) +3. This pattern is **thread-safe** for any timing (concurrent or sequential calls) --- @@ -1409,6 +1384,141 @@ func extractMessages(logs []*AuditLog) []string { --- +### Modern Go + +Prefer modern Go stdlib idioms over manual equivalents. Use these patterns consistently. + +#### `slices` package — avoid manual loops + +```go +slices.Contains(items, x) // instead of manual loop +slices.Index(items, x) // returns index or -1 +slices.IndexFunc(items, func(item T) bool { return item.ID == id }) +slices.SortFunc(items, func(a, b T) int { return cmp.Compare(a.X, b.X) }) +slices.Sort(items) // for ordered types +slices.Max(items) / slices.Min(items) // instead of manual loop +slices.Reverse(items) // in-place +slices.Compact(items) // remove consecutive duplicates +slices.Clone(s) // shallow copy +slices.Clip(s) // trim unused capacity +``` + +#### `any` instead of `interface{}` + +```go +// good +func process(value any) {} + +// bad +func process(value interface{}) {} +``` + +#### `sync.OnceFunc` / `sync.OnceValue` + +```go +// instead of sync.Once + wrapper +f := sync.OnceFunc(func() { initialize() }) + +// compute-once getter +getValue := sync.OnceValue(func() int { return expensiveComputation() }) +``` + +#### `context` helpers + +```go +stop := context.AfterFunc(ctx, cleanup) // run cleanup on cancellation +ctx, cancel := context.WithTimeoutCause(parent, d, ErrTimeout) // timeout with cause +ctx, cancel := context.WithDeadlineCause(parent, deadline, ErrDeadline) // deadline with cause +``` + +#### Range over integer + +```go +// good +for i := range len(items) { ... } + +// bad +for i := 0; i < len(items); i++ { ... } +``` + +#### `t.Context()` in tests + +Always use `t.Context()` — it cancels automatically when the test ends. + +```go +// good +func TestFoo(t *testing.T) { + ctx := t.Context() + result := doSomething(ctx) +} + +// bad +func TestFoo(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + result := doSomething(ctx) +} +``` + +#### `omitzero` instead of `omitempty` + +Use `omitzero` for `time.Duration`, `time.Time`, structs, slices, and maps — `omitempty` does not work correctly for these types. + +```go +// good +type Config struct { + Timeout time.Duration `json:"timeout,omitzero"` + CreatedAt time.Time `json:"createdAt,omitzero"` +} + +// bad +type Config struct { + Timeout time.Duration `json:"timeout,omitempty"` // broken for Duration! + CreatedAt time.Time `json:"createdAt,omitempty"` +} +``` + +#### `wg.Go()` instead of `wg.Add(1)` + goroutine + +```go +// good +var wg sync.WaitGroup +for _, item := range items { + wg.Go(func() { process(item) }) +} +wg.Wait() + +// bad +var wg sync.WaitGroup +for _, item := range items { + wg.Add(1) + go func() { + defer wg.Done() + process(item) + }() +} +wg.Wait() +``` + +#### `new(val)` for pointer literals + +`new()` accepts expressions since Go 1.26 — avoids the temporary-variable pattern. + +```go +// good +cfg := Config{ + Timeout: new(30), // *int + Debug: new(true), // *bool +} + +// bad +timeout := 30 +debug := true +cfg := Config{Timeout: &timeout, Debug: &debug} +``` + +--- + ## Frontend guidelines ### React component structure diff --git a/agent/internal/config/config.go b/agent/internal/config/config.go index a70e5a2..2fbc856 100644 --- a/agent/internal/config/config.go +++ b/agent/internal/config/config.go @@ -110,8 +110,7 @@ func (c *Config) applyDefaults() { } if c.IsDeleteWalAfterUpload == nil { - v := true - c.IsDeleteWalAfterUpload = &v + c.IsDeleteWalAfterUpload = new(true) } } diff --git a/agent/internal/features/full_backup/backuper_test.go b/agent/internal/features/full_backup/backuper_test.go index 0f4a8f1..7d423a4 100644 --- a/agent/internal/features/full_backup/backuper_test.go +++ b/agent/internal/features/full_backup/backuper_test.go @@ -71,7 +71,7 @@ func Test_RunFullBackup_WhenChainBroken_BasebackupTriggered(t *testing.T) { fb := newTestFullBackuper(server.URL) fb.cmdBuilder = mockCmdBuilder(t, "test-backup-data", validStderr()) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) defer cancel() go fb.Run(ctx) @@ -124,7 +124,7 @@ func Test_RunFullBackup_WhenScheduledBackupDue_BasebackupTriggered(t *testing.T) fb := newTestFullBackuper(server.URL) fb.cmdBuilder = mockCmdBuilder(t, "scheduled-backup-data", validStderr()) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) defer cancel() go fb.Run(ctx) @@ -169,7 +169,7 @@ func Test_RunFullBackup_WhenNoFullBackupExists_ImmediateBasebackupTriggered(t *t fb := newTestFullBackuper(server.URL) fb.cmdBuilder = mockCmdBuilder(t, "first-backup-data", validStderr()) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) defer cancel() go fb.Run(ctx) @@ -233,7 +233,7 @@ func Test_RunFullBackup_WhenUploadFails_RetriesAfterDelay(t *testing.T) { setRetryDelay(100 * time.Millisecond) defer setRetryDelay(origRetryDelay) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) defer cancel() go fb.Run(ctx) @@ -282,7 +282,7 @@ func Test_RunFullBackup_WhenAlreadyRunning_SkipsExecution(t *testing.T) { fb.isRunning.Store(true) - fb.checkAndRunIfNeeded(context.Background()) + fb.checkAndRunIfNeeded(t.Context()) mu.Lock() count := uploadCount @@ -318,7 +318,7 @@ func Test_RunFullBackup_WhenContextCancelled_StopsCleanly(t *testing.T) { setRetryDelay(5 * time.Second) defer setRetryDelay(origRetryDelay) - ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + ctx, cancel := context.WithTimeout(t.Context(), 500*time.Millisecond) defer cancel() done := make(chan struct{}) @@ -360,7 +360,7 @@ func Test_RunFullBackup_WhenChainValidAndNotScheduled_NoBasebackupTriggered(t *t fb := newTestFullBackuper(server.URL) fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr()) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) defer cancel() go fb.Run(ctx) @@ -411,7 +411,7 @@ func Test_RunFullBackup_WhenStderrParsingFails_FinalizesWithErrorAndRetries(t *t setRetryDelay(100 * time.Millisecond) defer setRetryDelay(origRetryDelay) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) defer cancel() go fb.Run(ctx) @@ -458,7 +458,7 @@ func Test_RunFullBackup_WhenNextBackupTimeNull_BasebackupTriggered(t *testing.T) fb := newTestFullBackuper(server.URL) fb.cmdBuilder = mockCmdBuilder(t, "first-run-data", validStderr()) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) defer cancel() go fb.Run(ctx) @@ -498,7 +498,7 @@ func Test_RunFullBackup_WhenChainValidityReturns401_NoBasebackupTriggered(t *tes fb := newTestFullBackuper(server.URL) fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr()) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) defer cancel() go fb.Run(ctx) @@ -538,7 +538,7 @@ func Test_RunFullBackup_WhenUploadSucceeds_BodyIsZstdCompressed(t *testing.T) { fb := newTestFullBackuper(server.URL) fb.cmdBuilder = mockCmdBuilder(t, originalContent, validStderr()) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) defer cancel() go fb.Run(ctx) diff --git a/agent/internal/features/restore/restorer_test.go b/agent/internal/features/restore/restorer_test.go index 8883edd..385448b 100644 --- a/agent/internal/features/restore/restorer_test.go +++ b/agent/internal/features/restore/restorer_test.go @@ -3,7 +3,6 @@ package restore import ( "archive/tar" "bytes" - "context" "encoding/json" "fmt" "net/http" @@ -86,7 +85,7 @@ func Test_RunRestore_WhenBasebackupAndWalSegmentsAvailable_FilesExtractedAndReco targetDir := createTestTargetDir(t) restorer := newTestRestorer(server.URL, targetDir, "", "", "") - err := restorer.Run(context.Background()) + err := restorer.Run(t.Context()) require.NoError(t, err) pgVersionContent, err := os.ReadFile(filepath.Join(targetDir, "PG_VERSION")) @@ -152,7 +151,7 @@ func Test_RunRestore_WhenTargetTimeProvided_RecoveryTargetTimeWrittenToConfig(t targetDir := createTestTargetDir(t) restorer := newTestRestorer(server.URL, targetDir, "", "2026-02-28T14:30:00Z", "") - err := restorer.Run(context.Background()) + err := restorer.Run(t.Context()) require.NoError(t, err) autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf")) @@ -169,7 +168,7 @@ func Test_RunRestore_WhenPgDataDirNotEmpty_ReturnsError(t *testing.T) { restorer := newTestRestorer("http://localhost:0", targetDir, "", "", "") - err = restorer.Run(context.Background()) + err = restorer.Run(t.Context()) require.Error(t, err) assert.Contains(t, err.Error(), "not empty") } @@ -179,7 +178,7 @@ func Test_RunRestore_WhenPgDataDirDoesNotExist_ReturnsError(t *testing.T) { restorer := newTestRestorer("http://localhost:0", nonExistentDir, "", "", "") - err := restorer.Run(context.Background()) + err := restorer.Run(t.Context()) require.Error(t, err) assert.Contains(t, err.Error(), "does not exist") } @@ -197,7 +196,7 @@ func Test_RunRestore_WhenNoBackupsAvailable_ReturnsError(t *testing.T) { targetDir := createTestTargetDir(t) restorer := newTestRestorer(server.URL, targetDir, "", "", "") - err := restorer.Run(context.Background()) + err := restorer.Run(t.Context()) require.Error(t, err) assert.Contains(t, err.Error(), "No full backups available") } @@ -216,7 +215,7 @@ func Test_RunRestore_WhenWalChainBroken_ReturnsError(t *testing.T) { targetDir := createTestTargetDir(t) restorer := newTestRestorer(server.URL, targetDir, "", "", "") - err := restorer.Run(context.Background()) + err := restorer.Run(t.Context()) require.Error(t, err) assert.Contains(t, err.Error(), "WAL chain broken") assert.Contains(t, err.Error(), testWalSegment1) @@ -282,7 +281,7 @@ func Test_DownloadWalSegment_WhenFirstAttemptFails_RetriesAndSucceeds(t *testing retryDelayOverride = &testDelay defer func() { retryDelayOverride = origDelay }() - err := restorer.Run(context.Background()) + err := restorer.Run(t.Context()) require.NoError(t, err) mu.Lock() @@ -341,7 +340,7 @@ func Test_DownloadWalSegment_WhenAllAttemptsFail_ReturnsErrorWithSegmentName(t * retryDelayOverride = &testDelay defer func() { retryDelayOverride = origDelay }() - err := restorer.Run(context.Background()) + err := restorer.Run(t.Context()) require.Error(t, err) assert.Contains(t, err.Error(), testWalSegment1) assert.Contains(t, err.Error(), "3 attempts") @@ -351,7 +350,7 @@ func Test_RunRestore_WhenInvalidTargetTimeFormat_ReturnsError(t *testing.T) { targetDir := createTestTargetDir(t) restorer := newTestRestorer("http://localhost:0", targetDir, "", "not-a-valid-time", "") - err := restorer.Run(context.Background()) + err := restorer.Run(t.Context()) require.Error(t, err) assert.Contains(t, err.Error(), "invalid --target-time format") } @@ -384,7 +383,7 @@ func Test_RunRestore_WhenBasebackupDownloadFails_ReturnsError(t *testing.T) { targetDir := createTestTargetDir(t) restorer := newTestRestorer(server.URL, targetDir, "", "", "") - err := restorer.Run(context.Background()) + err := restorer.Run(t.Context()) require.Error(t, err) assert.Contains(t, err.Error(), "basebackup download failed") } @@ -423,7 +422,7 @@ func Test_RunRestore_WhenNoWalSegmentsInPlan_BasebackupRestoredSuccessfully(t *t targetDir := createTestTargetDir(t) restorer := newTestRestorer(server.URL, targetDir, "", "", "") - err := restorer.Run(context.Background()) + err := restorer.Run(t.Context()) require.NoError(t, err) pgVersionContent, err := os.ReadFile(filepath.Join(targetDir, "PG_VERSION")) @@ -486,7 +485,7 @@ func Test_RunRestore_WhenMakingApiCalls_AuthTokenIncludedInRequests(t *testing.T targetDir := createTestTargetDir(t) restorer := newTestRestorer(server.URL, targetDir, "", "", "") - err := restorer.Run(context.Background()) + err := restorer.Run(t.Context()) require.NoError(t, err) assert.GreaterOrEqual(t, int(receivedAuthHeaders.Load()), 2) @@ -530,7 +529,7 @@ func Test_ConfigurePostgresRecovery_WhenPgTypeHost_UsesHostAbsolutePath(t *testi targetDir := createTestTargetDir(t) restorer := newTestRestorer(server.URL, targetDir, "", "", "host") - err := restorer.Run(context.Background()) + err := restorer.Run(t.Context()) require.NoError(t, err) autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf")) @@ -577,7 +576,7 @@ func Test_ConfigurePostgresRecovery_WhenPgTypeDocker_UsesContainerPath(t *testin targetDir := createTestTargetDir(t) restorer := newTestRestorer(server.URL, targetDir, "", "", "docker") - err := restorer.Run(context.Background()) + err := restorer.Run(t.Context()) require.NoError(t, err) autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf")) diff --git a/agent/internal/features/start/lock_watcher_test.go b/agent/internal/features/start/lock_watcher_test.go index d1e3ab2..55352de 100644 --- a/agent/internal/features/start/lock_watcher_test.go +++ b/agent/internal/features/start/lock_watcher_test.go @@ -21,7 +21,7 @@ func Test_NewLockWatcher_CapturesInode(t *testing.T) { require.NoError(t, err) defer ReleaseLock(lockFile) - _, cancel := context.WithCancel(context.Background()) + _, cancel := context.WithCancel(t.Context()) defer cancel() watcher, err := NewLockWatcher(lockFile, cancel, log) @@ -37,7 +37,7 @@ func Test_LockWatcher_FileUnchanged_ContextNotCancelled(t *testing.T) { require.NoError(t, err) defer ReleaseLock(lockFile) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() watcher, err := NewLockWatcher(lockFile, cancel, log) @@ -62,7 +62,7 @@ func Test_LockWatcher_FileDeleted_CancelsContext(t *testing.T) { require.NoError(t, err) defer ReleaseLock(lockFile) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() watcher, err := NewLockWatcher(lockFile, cancel, log) @@ -88,7 +88,7 @@ func Test_LockWatcher_FileReplacedWithDifferentInode_CancelsContext(t *testing.T require.NoError(t, err) defer ReleaseLock(lockFile) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() watcher, err := NewLockWatcher(lockFile, cancel, log) diff --git a/agent/internal/features/wal/streamer.go b/agent/internal/features/wal/streamer.go index 955a027..7ae5337 100644 --- a/agent/internal/features/wal/streamer.go +++ b/agent/internal/features/wal/streamer.go @@ -8,7 +8,7 @@ import ( "os" "path/filepath" "regexp" - "sort" + "slices" "strings" "time" @@ -113,7 +113,7 @@ func (s *Streamer) listSegments() ([]string, error) { segments = append(segments, name) } - sort.Strings(segments) + slices.Sort(segments) return segments, nil } diff --git a/agent/internal/features/wal/streamer_test.go b/agent/internal/features/wal/streamer_test.go index b9deceb..fc161a8 100644 --- a/agent/internal/features/wal/streamer_test.go +++ b/agent/internal/features/wal/streamer_test.go @@ -42,7 +42,7 @@ func Test_UploadSegment_SingleSegment_ServerReceivesCorrectHeadersAndBody(t *tes streamer := newTestStreamer(walDir, server.URL) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second) defer cancel() go streamer.Run(ctx) @@ -79,7 +79,7 @@ func Test_UploadSegments_MultipleSegmentsOutOfOrder_UploadedInAscendingOrder(t * streamer := newTestStreamer(walDir, server.URL) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second) defer cancel() go streamer.Run(ctx) @@ -115,7 +115,7 @@ func Test_UploadSegments_DirectoryHasTmpFiles_TmpFilesIgnored(t *testing.T) { streamer := newTestStreamer(walDir, server.URL) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second) defer cancel() go streamer.Run(ctx) @@ -146,7 +146,7 @@ func Test_UploadSegment_DeleteEnabled_FileRemovedAfterUpload(t *testing.T) { apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger()) streamer := NewStreamer(cfg, apiClient, logger.GetLogger()) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second) defer cancel() go streamer.Run(ctx) @@ -174,7 +174,7 @@ func Test_UploadSegment_DeleteDisabled_FileKeptAfterUpload(t *testing.T) { apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger()) streamer := NewStreamer(cfg, apiClient, logger.GetLogger()) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second) defer cancel() go streamer.Run(ctx) @@ -199,7 +199,7 @@ func Test_UploadSegment_ServerReturns500_FileKeptInQueue(t *testing.T) { streamer := newTestStreamer(walDir, server.URL) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second) defer cancel() go streamer.Run(ctx) @@ -223,7 +223,7 @@ func Test_ProcessQueue_EmptyDirectory_NoUploads(t *testing.T) { streamer := newTestStreamer(walDir, server.URL) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second) defer cancel() go streamer.Run(ctx) @@ -238,7 +238,7 @@ func Test_Run_ContextCancelled_StopsImmediately(t *testing.T) { streamer := newTestStreamer(walDir, "http://localhost:0") - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) cancel() done := make(chan struct{}) @@ -276,7 +276,7 @@ func Test_UploadSegment_ServerReturns409_FileNotDeleted(t *testing.T) { streamer := newTestStreamer(walDir, server.URL) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second) defer cancel() go streamer.Run(ctx) diff --git a/agent/internal/logger/logger.go b/agent/internal/logger/logger.go index f57c3dc..588d484 100644 --- a/agent/internal/logger/logger.go +++ b/agent/internal/logger/logger.go @@ -64,16 +64,12 @@ func (w *rotatingWriter) rotate() error { return nil } -var ( - loggerInstance *slog.Logger - once sync.Once -) +var loggerInstance *slog.Logger + +var initLogger = sync.OnceFunc(initialize) func GetLogger() *slog.Logger { - once.Do(func() { - initialize() - }) - + initLogger() return loggerInstance } diff --git a/agent/internal/logger/logger_test.go b/agent/internal/logger/logger_test.go index 022733e..c87b9e8 100644 --- a/agent/internal/logger/logger_test.go +++ b/agent/internal/logger/logger_test.go @@ -67,7 +67,7 @@ func Test_Write_MultipleSmallWrites_CurrentSizeAccumulated(t *testing.T) { rw, _, _ := setupRotatingWriter(t, 1024) var totalWritten int64 - for i := 0; i < 10; i++ { + for range 10 { data := []byte("line\n") n, err := rw.Write(data) require.NoError(t, err) diff --git a/backend/cmd/main.go b/backend/cmd/main.go index 9932108..6dec342 100644 --- a/backend/cmd/main.go +++ b/backend/cmd/main.go @@ -193,7 +193,7 @@ func startServerWithGracefulShutdown(log *slog.Logger, app *gin.Engine) { log.Info("Shutdown signal received") // Gracefully shutdown VictoriaLogs writer - logger.ShutdownVictoriaLogs(5 * time.Second) + logger.ShutdownVictoriaLogs() // The context is used to inform the server it has 10 seconds to finish // the request it is currently handling diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 7e8e7fa..bca600e 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -142,13 +142,12 @@ type EnvVariables struct { DatabasusURL string `env:"DATABASUS_URL"` } -var ( - env EnvVariables - once sync.Once -) +var env EnvVariables + +var initEnv = sync.OnceFunc(loadEnvVariables) func GetEnv() *EnvVariables { - once.Do(loadEnvVariables) + initEnv() return &env } diff --git a/backend/internal/features/audit_logs/background_service.go b/backend/internal/features/audit_logs/background_service.go index 5fc4812..ababa02 100644 --- a/backend/internal/features/audit_logs/background_service.go +++ b/backend/internal/features/audit_logs/background_service.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "log/slog" - "sync" "sync/atomic" "time" ) @@ -13,39 +12,32 @@ type AuditLogBackgroundService struct { auditLogService *AuditLogService logger *slog.Logger - runOnce sync.Once - hasRun atomic.Bool + hasRun atomic.Bool } func (s *AuditLogBackgroundService) Run(ctx context.Context) { - wasAlreadyRun := s.hasRun.Load() + if s.hasRun.Swap(true) { + panic(fmt.Sprintf("%T.Run() called multiple times", s)) + } - s.runOnce.Do(func() { - s.hasRun.Store(true) + s.logger.Info("Starting audit log cleanup background service") - s.logger.Info("Starting audit log cleanup background service") + if ctx.Err() != nil { + return + } - if ctx.Err() != nil { + ticker := time.NewTicker(1 * time.Hour) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): return - } - - ticker := time.NewTicker(1 * time.Hour) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if err := s.cleanOldAuditLogs(); err != nil { - s.logger.Error("Failed to clean old audit logs", "error", err) - } + case <-ticker.C: + if err := s.cleanOldAuditLogs(); err != nil { + s.logger.Error("Failed to clean old audit logs", "error", err) } } - }) - - if wasAlreadyRun { - panic(fmt.Sprintf("%T.Run() called multiple times", s)) } } diff --git a/backend/internal/features/audit_logs/background_service_test.go b/backend/internal/features/audit_logs/background_service_test.go index 9087010..9e6a046 100644 --- a/backend/internal/features/audit_logs/background_service_test.go +++ b/backend/internal/features/audit_logs/background_service_test.go @@ -102,7 +102,7 @@ func Test_CleanOldAuditLogs_DeletesMultipleOldLogs(t *testing.T) { // Create many old logs with specific UUIDs to track them testLogIDs := make([]uuid.UUID, 5) - for i := 0; i < 5; i++ { + for i := range 5 { testLogIDs[i] = uuid.New() daysAgo := 400 + (i * 10) log := &AuditLog{ diff --git a/backend/internal/features/audit_logs/di.go b/backend/internal/features/audit_logs/di.go index 9264780..a0640db 100644 --- a/backend/internal/features/audit_logs/di.go +++ b/backend/internal/features/audit_logs/di.go @@ -2,7 +2,6 @@ package audit_logs import ( "sync" - "sync/atomic" users_services "databasus-backend/internal/features/users/services" "databasus-backend/internal/util/logger" @@ -23,8 +22,6 @@ var auditLogController = &AuditLogController{ var auditLogBackgroundService = &AuditLogBackgroundService{ auditLogService: auditLogService, logger: logger.GetLogger(), - runOnce: sync.Once{}, - hasRun: atomic.Bool{}, } func GetAuditLogService() *AuditLogService { @@ -39,23 +36,8 @@ func GetAuditLogBackgroundService() *AuditLogBackgroundService { return auditLogBackgroundService } -var ( - setupOnce sync.Once - isSetup atomic.Bool -) - -func SetupDependencies() { - wasAlreadySetup := isSetup.Load() - - setupOnce.Do(func() { - users_services.GetUserService().SetAuditLogWriter(auditLogService) - users_services.GetSettingsService().SetAuditLogWriter(auditLogService) - users_services.GetManagementService().SetAuditLogWriter(auditLogService) - - isSetup.Store(true) - }) - - if wasAlreadySetup { - logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") - } -} +var SetupDependencies = sync.OnceFunc(func() { + users_services.GetUserService().SetAuditLogWriter(auditLogService) + users_services.GetSettingsService().SetAuditLogWriter(auditLogService) + users_services.GetManagementService().SetAuditLogWriter(auditLogService) +}) diff --git a/backend/internal/features/backups/backups/backuping/backuper.go b/backend/internal/features/backups/backups/backuping/backuper.go index f9c4552..5a98ed1 100644 --- a/backend/internal/features/backups/backups/backuping/backuper.go +++ b/backend/internal/features/backups/backups/backuping/backuper.go @@ -9,7 +9,6 @@ import ( "log/slog" "slices" "strings" - "sync" "sync/atomic" "time" @@ -46,80 +45,73 @@ type BackuperNode struct { lastHeartbeat time.Time - runOnce sync.Once - hasRun atomic.Bool + hasRun atomic.Bool } func (n *BackuperNode) Run(ctx context.Context) { - wasAlreadyRun := n.hasRun.Load() + if n.hasRun.Swap(true) { + panic(fmt.Sprintf("%T.Run() called multiple times", n)) + } - n.runOnce.Do(func() { - n.hasRun.Store(true) + n.lastHeartbeat = time.Now().UTC() - n.lastHeartbeat = time.Now().UTC() + throughputMBs := config.GetEnv().NodeNetworkThroughputMBs - throughputMBs := config.GetEnv().NodeNetworkThroughputMBs + backupNode := BackupNode{ + ID: n.nodeID, + ThroughputMBs: throughputMBs, + LastHeartbeat: time.Now().UTC(), + } - backupNode := BackupNode{ - ID: n.nodeID, - ThroughputMBs: throughputMBs, - LastHeartbeat: time.Now().UTC(), - } + if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil { + n.logger.Error("Failed to register node in registry", "error", err) + panic(err) + } - if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil { - n.logger.Error("Failed to register node in registry", "error", err) - panic(err) - } - - backupHandler := func(backupID uuid.UUID, isCallNotifier bool) { - go func() { - n.MakeBackup(backupID, isCallNotifier) - if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil { - n.logger.Error( - "Failed to publish backup completion", - "error", - err, - "backupID", - backupID, - ) - } - }() - } - - err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler) - if err != nil { - n.logger.Error("Failed to subscribe to backup assignments", "error", err) - panic(err) - } - defer func() { - if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil { - n.logger.Error("Failed to unsubscribe from backup assignments", "error", err) + backupHandler := func(backupID uuid.UUID, isCallNotifier bool) { + go func() { + n.MakeBackup(backupID, isCallNotifier) + if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil { + n.logger.Error( + "Failed to publish backup completion", + "error", + err, + "backupID", + backupID, + ) } }() + } - ticker := time.NewTicker(heartbeatTickerInterval) - defer ticker.Stop() - - n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs) - - for { - select { - case <-ctx.Done(): - n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID) - - if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil { - n.logger.Error("Failed to unregister node from registry", "error", err) - } - - return - case <-ticker.C: - n.sendHeartbeat(&backupNode) - } + err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler) + if err != nil { + n.logger.Error("Failed to subscribe to backup assignments", "error", err) + panic(err) + } + defer func() { + if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil { + n.logger.Error("Failed to unsubscribe from backup assignments", "error", err) } - }) + }() - if wasAlreadyRun { - panic(fmt.Sprintf("%T.Run() called multiple times", n)) + ticker := time.NewTicker(heartbeatTickerInterval) + defer ticker.Stop() + + n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs) + + for { + select { + case <-ctx.Done(): + n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID) + + if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil { + n.logger.Error("Failed to unregister node from registry", "error", err) + } + + return + case <-ticker.C: + n.sendHeartbeat(&backupNode) + } } } diff --git a/backend/internal/features/backups/backups/backuping/cleaner.go b/backend/internal/features/backups/backups/backuping/cleaner.go index 0016581..97449ea 100644 --- a/backend/internal/features/backups/backups/backuping/cleaner.go +++ b/backend/internal/features/backups/backups/backuping/cleaner.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "log/slog" - "sync" "sync/atomic" "time" @@ -32,49 +31,42 @@ type BackupCleaner struct { logger *slog.Logger backupRemoveListeners []backups_core.BackupRemoveListener - runOnce sync.Once - hasRun atomic.Bool + hasRun atomic.Bool } func (c *BackupCleaner) Run(ctx context.Context) { - wasAlreadyRun := c.hasRun.Load() + if c.hasRun.Swap(true) { + panic(fmt.Sprintf("%T.Run() called multiple times", c)) + } - c.runOnce.Do(func() { - c.hasRun.Store(true) + if ctx.Err() != nil { + return + } - if ctx.Err() != nil { + retentionLog := c.logger.With("task_name", "clean_by_retention_policy") + exceededLog := c.logger.With("task_name", "clean_exceeded_storage_backups") + staleLog := c.logger.With("task_name", "clean_stale_basebackups") + + ticker := time.NewTicker(cleanerTickerInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): return - } + case <-ticker.C: + if err := c.cleanByRetentionPolicy(retentionLog); err != nil { + retentionLog.Error("failed to clean backups by retention policy", "error", err) + } - retentionLog := c.logger.With("task_name", "clean_by_retention_policy") - exceededLog := c.logger.With("task_name", "clean_exceeded_storage_backups") - staleLog := c.logger.With("task_name", "clean_stale_basebackups") + if err := c.cleanExceededStorageBackups(exceededLog); err != nil { + exceededLog.Error("failed to clean exceeded backups", "error", err) + } - ticker := time.NewTicker(cleanerTickerInterval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if err := c.cleanByRetentionPolicy(retentionLog); err != nil { - retentionLog.Error("failed to clean backups by retention policy", "error", err) - } - - if err := c.cleanExceededStorageBackups(exceededLog); err != nil { - exceededLog.Error("failed to clean exceeded backups", "error", err) - } - - if err := c.cleanStaleUploadedBasebackups(staleLog); err != nil { - staleLog.Error("failed to clean stale uploaded basebackups", "error", err) - } + if err := c.cleanStaleUploadedBasebackups(staleLog); err != nil { + staleLog.Error("failed to clean stale uploaded basebackups", "error", err) } } - }) - - if wasAlreadyRun { - panic(fmt.Sprintf("%T.Run() called multiple times", c)) } } diff --git a/backend/internal/features/backups/backups/backuping/cleaner_gfs_test.go b/backend/internal/features/backups/backups/backuping/cleaner_gfs_test.go index a1de206..87b2d0c 100644 --- a/backend/internal/features/backups/backups/backuping/cleaner_gfs_test.go +++ b/backend/internal/features/backups/backups/backuping/cleaner_gfs_test.go @@ -33,7 +33,7 @@ func Test_BuildGFSKeepSet(t *testing.T) { // backupsEveryDay returns n backups, newest-first, each 1 day apart. backupsEveryDay := func(n int) []*backups_core.Backup { bs := make([]*backups_core.Backup, n) - for i := 0; i < n; i++ { + for i := range n { bs[i] = newBackup(ref.Add(-time.Duration(i) * day)) } return bs @@ -42,7 +42,7 @@ func Test_BuildGFSKeepSet(t *testing.T) { // backupsEveryWeek returns n backups, newest-first, each 7 days apart. backupsEveryWeek := func(n int) []*backups_core.Backup { bs := make([]*backups_core.Backup, n) - for i := 0; i < n; i++ { + for i := range n { bs[i] = newBackup(ref.Add(-time.Duration(i) * week)) } return bs @@ -53,7 +53,7 @@ func Test_BuildGFSKeepSet(t *testing.T) { // backupsEveryHour returns n backups, newest-first, each 1 hour apart. backupsEveryHour := func(n int) []*backups_core.Backup { bs := make([]*backups_core.Backup, n) - for i := 0; i < n; i++ { + for i := range n { bs[i] = newBackup(ref.Add(-time.Duration(i) * hour)) } return bs @@ -62,7 +62,7 @@ func Test_BuildGFSKeepSet(t *testing.T) { // backupsEveryMonth returns n backups, newest-first, each ~1 month apart. backupsEveryMonth := func(n int) []*backups_core.Backup { bs := make([]*backups_core.Backup, n) - for i := 0; i < n; i++ { + for i := range n { bs[i] = newBackup(ref.AddDate(0, -i, 0)) } return bs @@ -71,7 +71,7 @@ func Test_BuildGFSKeepSet(t *testing.T) { // backupsEveryYear returns n backups, newest-first, each 1 year apart. backupsEveryYear := func(n int) []*backups_core.Backup { bs := make([]*backups_core.Backup, n) - for i := 0; i < n; i++ { + for i := range n { bs[i] = newBackup(ref.AddDate(-i, 0, 0)) } return bs @@ -410,7 +410,7 @@ func Test_CleanByGFS_KeepsCorrectBackupsPerSlot(t *testing.T) { // Create 5 backups on 5 different days; only the 3 newest days should be kept var backupIDs []uuid.UUID - for i := 0; i < 5; i++ { + for i := range 5 { backup := &backups_core.Backup{ ID: uuid.New(), DatabaseID: database.ID, @@ -486,7 +486,7 @@ func Test_CleanByGFS_WithWeeklyAndMonthlySlots_KeepsWiderSpread(t *testing.T) { // Create one backup per week for 6 weeks (each on Monday of that week) // GFS should keep: 2 daily (most recent 2 unique days) + 2 weekly + 1 monthly = up to 5 unique var createdIDs []uuid.UUID - for i := 0; i < 6; i++ { + for i := range 6 { weekOffset := time.Duration(5-i) * 7 * 24 * time.Hour backup := &backups_core.Backup{ ID: uuid.New(), @@ -561,7 +561,7 @@ func Test_CleanByGFS_WithHourlySlots_KeepsCorrectBackups(t *testing.T) { // Create 5 backups spaced 1 hour apart; only the 3 newest hours should be kept var backupIDs []uuid.UUID - for i := 0; i < 5; i++ { + for i := range 5 { backup := &backups_core.Backup{ ID: uuid.New(), DatabaseID: database.ID, @@ -824,8 +824,8 @@ func Test_CleanByGFS_WithMultipleBackupsPerDay_KeepsOnlyOnePerDailySlot(t *testi // Create 3 backups per day for 10 days = 30 total, all beyond grace period. // Each day gets backups at base+0h, base+6h, base+12h. - for day := 0; day < 10; day++ { - for sub := 0; sub < 3; sub++ { + for day := range 10 { + for sub := range 3 { backup := &backups_core.Backup{ ID: uuid.New(), DatabaseID: database.ID, @@ -915,7 +915,7 @@ func Test_CleanByGFS_With24HourlySlotsAnd23DailyBackups_DeletesExcessBackups(t * now := time.Now().UTC() - for i := 0; i < 23; i++ { + for i := range 23 { backup := &backups_core.Backup{ ID: uuid.New(), DatabaseID: database.ID, @@ -985,7 +985,7 @@ func Test_CleanByGFS_WithDisabledHourlySlotsAnd23DailyBackups_DeletesExcessBacku now := time.Now().UTC() - for i := 0; i < 23; i++ { + for i := range 23 { backup := &backups_core.Backup{ ID: uuid.New(), DatabaseID: database.ID, @@ -1055,7 +1055,7 @@ func Test_CleanByGFS_WithDailySlotsAndWeeklyBackups_DeletesExcessBackups(t *test // Create 10 weekly backups (1 per week, all >2h old past grace period). // With 7d/4w config, correct behavior: ~8 kept (4 weekly + overlap with daily for recent ones). // Daily slots should NOT absorb weekly backups that are older than 7 days. - for i := 0; i < 10; i++ { + for i := range 10 { backup := &backups_core.Backup{ ID: uuid.New(), DatabaseID: database.ID, @@ -1138,7 +1138,7 @@ func Test_CleanByGFS_WithWeeklySlotsAndMonthlyBackups_DeletesExcessBackups(t *te // With 52w/3m config, correct behavior: 3 kept (3 monthly slots; weekly should only // cover recent 52 weeks but not artificially retain old monthly backups). // Bug: all 8 kept because each monthly backup fills a unique weekly slot. - for i := 0; i < 8; i++ { + for i := range 8 { backup := &backups_core.Backup{ ID: uuid.New(), DatabaseID: database.ID, diff --git a/backend/internal/features/backups/backups/backuping/cleaner_test.go b/backend/internal/features/backups/backups/backuping/cleaner_test.go index d629483..1ed9d17 100644 --- a/backend/internal/features/backups/backups/backuping/cleaner_test.go +++ b/backend/internal/features/backups/backups/backuping/cleaner_test.go @@ -197,7 +197,7 @@ func Test_CleanExceededBackups_WhenUnderStorageLimit_NoBackupsDeleted(t *testing _, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig) assert.NoError(t, err) - for i := 0; i < 3; i++ { + for i := range 3 { backup := &backups_core.Backup{ ID: uuid.New(), DatabaseID: database.ID, @@ -263,7 +263,7 @@ func Test_CleanExceededBackups_WhenOverStorageLimit_DeletesOldestBackups(t *test // Expect 2 oldest deleted, 3 remain (900 MB < 1024 MB) now := time.Now().UTC() var backupIDs []uuid.UUID - for i := 0; i < 5; i++ { + for i := range 5 { backup := &backups_core.Backup{ ID: uuid.New(), DatabaseID: database.ID, @@ -340,7 +340,7 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) { // 3 completed at 500 MB each = 1500 MB, limit = 1 GB (1024 MB) completedBackups := make([]*backups_core.Backup, 3) - for i := 0; i < 3; i++ { + for i := range 3 { backup := &backups_core.Backup{ ID: uuid.New(), DatabaseID: database.ID, @@ -423,7 +423,7 @@ func Test_CleanExceededBackups_WithZeroStorageLimit_RemovesAllBackups(t *testing _, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig) assert.NoError(t, err) - for i := 0; i < 10; i++ { + for i := range 10 { backup := &backups_core.Backup{ ID: uuid.New(), DatabaseID: database.ID, @@ -555,7 +555,7 @@ func Test_CleanByCount_KeepsNewestNBackups_DeletesOlder(t *testing.T) { now := time.Now().UTC() var backupIDs []uuid.UUID - for i := 0; i < 5; i++ { + for i := range 5 { backup := &backups_core.Backup{ ID: uuid.New(), DatabaseID: database.ID, @@ -626,7 +626,7 @@ func Test_CleanByCount_WhenUnderLimit_NoBackupsDeleted(t *testing.T) { _, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig) assert.NoError(t, err) - for i := 0; i < 5; i++ { + for i := range 5 { backup := &backups_core.Backup{ ID: uuid.New(), DatabaseID: database.ID, @@ -686,7 +686,7 @@ func Test_CleanByCount_DoesNotDeleteInProgressBackups(t *testing.T) { now := time.Now().UTC() - for i := 0; i < 3; i++ { + for i := range 3 { backup := &backups_core.Backup{ ID: uuid.New(), DatabaseID: database.ID, @@ -1064,7 +1064,7 @@ func Test_CleanExceededStorageBackups_WhenNonCloud_SkipsCleanup(t *testing.T) { // 5 backups at 500 MB each = 2500 MB, would exceed 1 GB limit in cloud mode now := time.Now().UTC() - for i := 0; i < 5; i++ { + for i := range 5 { backup := &backups_core.Backup{ ID: uuid.New(), DatabaseID: database.ID, diff --git a/backend/internal/features/backups/backups/backuping/di.go b/backend/internal/features/backups/backups/backuping/di.go index 623449d..332b62d 100644 --- a/backend/internal/features/backups/backups/backuping/di.go +++ b/backend/internal/features/backups/backups/backuping/di.go @@ -1,7 +1,6 @@ package backuping import ( - "sync" "sync/atomic" "time" @@ -33,7 +32,6 @@ var backupCleaner = &BackupCleaner{ encryption.GetFieldEncryptor(), logger.GetLogger(), []backups_core.BackupRemoveListener{}, - sync.Once{}, atomic.Bool{}, } @@ -43,7 +41,6 @@ var backupNodesRegistry = &BackupNodesRegistry{ cache_utils.DefaultCacheTimeout, cache_utils.NewPubSubManager(), cache_utils.NewPubSubManager(), - sync.Once{}, atomic.Bool{}, } @@ -65,7 +62,6 @@ var backuperNode = &BackuperNode{ usecases.GetCreateBackupUsecase(), getNodeID(), time.Time{}, - sync.Once{}, atomic.Bool{}, } @@ -80,7 +76,6 @@ var backupsScheduler = &BackupsScheduler{ logger.GetLogger(), make(map[uuid.UUID]BackupToNodeRelation), backuperNode, - sync.Once{}, atomic.Bool{}, } diff --git a/backend/internal/features/backups/backups/backuping/registry.go b/backend/internal/features/backups/backups/backuping/registry.go index 415d41d..b5d6305 100644 --- a/backend/internal/features/backups/backups/backuping/registry.go +++ b/backend/internal/features/backups/backups/backuping/registry.go @@ -6,7 +6,6 @@ import ( "fmt" "log/slog" "strings" - "sync" "sync/atomic" "time" @@ -50,36 +49,30 @@ type BackupNodesRegistry struct { pubsubBackups *cache_utils.PubSubManager pubsubCompletions *cache_utils.PubSubManager - runOnce sync.Once - hasRun atomic.Bool + hasRun atomic.Bool } func (r *BackupNodesRegistry) Run(ctx context.Context) { - wasAlreadyRun := r.hasRun.Load() + if r.hasRun.Swap(true) { + panic(fmt.Sprintf("%T.Run() called multiple times", r)) + } - r.runOnce.Do(func() { - r.hasRun.Store(true) + if err := r.cleanupDeadNodes(); err != nil { + r.logger.Error("Failed to cleanup dead nodes on startup", "error", err) + } - if err := r.cleanupDeadNodes(); err != nil { - r.logger.Error("Failed to cleanup dead nodes on startup", "error", err) - } + ticker := time.NewTicker(cleanupTickerInterval) + defer ticker.Stop() - ticker := time.NewTicker(cleanupTickerInterval) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if err := r.cleanupDeadNodes(); err != nil { - r.logger.Error("Failed to cleanup dead nodes", "error", err) - } + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := r.cleanupDeadNodes(); err != nil { + r.logger.Error("Failed to cleanup dead nodes", "error", err) } } - }) - - if wasAlreadyRun { - panic(fmt.Sprintf("%T.Run() called multiple times", r)) } } diff --git a/backend/internal/features/backups/backups/backuping/registry_test.go b/backend/internal/features/backups/backups/backuping/registry_test.go index 3c614fd..efc331a 100644 --- a/backend/internal/features/backups/backups/backuping/registry_test.go +++ b/backend/internal/features/backups/backups/backuping/registry_test.go @@ -4,8 +4,6 @@ import ( "context" "encoding/json" "fmt" - "sync" - "sync/atomic" "testing" "time" @@ -322,7 +320,7 @@ func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) { err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node) assert.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), registry.timeout) + ctx, cancel := context.WithTimeout(t.Context(), registry.timeout) defer cancel() invalidKey := nodeInfoKeyPrefix + uuid.New().String() + nodeInfoKeySuffix @@ -331,7 +329,7 @@ func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) { registry.client.B().Set().Key(invalidKey).Value("invalid json data").Build(), ) defer func() { - cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), registry.timeout) + cleanupCtx, cleanupCancel := context.WithTimeout(t.Context(), registry.timeout) defer cleanupCancel() registry.client.Do(cleanupCtx, registry.client.B().Del().Key(invalidKey).Build()) }() @@ -401,7 +399,7 @@ func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) { err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3) assert.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), registry.timeout) + ctx, cancel := context.WithTimeout(t.Context(), registry.timeout) defer cancel() key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix) @@ -419,7 +417,7 @@ func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) { modifiedData, err := json.Marshal(node) assert.NoError(t, err) - setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout) + setCtx, setCancel := context.WithTimeout(t.Context(), registry.timeout) defer setCancel() setResult := registry.client.Do( setCtx, @@ -464,7 +462,7 @@ func Test_GetBackupNodesStats_ExcludesStaleNodesFromCache(t *testing.T) { err = registry.IncrementBackupsInProgress(node3.ID) assert.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), registry.timeout) + ctx, cancel := context.WithTimeout(t.Context(), registry.timeout) defer cancel() key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix) @@ -482,7 +480,7 @@ func Test_GetBackupNodesStats_ExcludesStaleNodesFromCache(t *testing.T) { modifiedData, err := json.Marshal(node) assert.NoError(t, err) - setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout) + setCtx, setCancel := context.WithTimeout(t.Context(), registry.timeout) defer setCancel() setResult := registry.client.Do( setCtx, @@ -524,7 +522,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) { err = registry.IncrementBackupsInProgress(node2.ID) assert.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), registry.timeout) + ctx, cancel := context.WithTimeout(t.Context(), registry.timeout) defer cancel() key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix) @@ -542,7 +540,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) { modifiedData, err := json.Marshal(node) assert.NoError(t, err) - setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout) + setCtx, setCancel := context.WithTimeout(t.Context(), registry.timeout) defer setCancel() setResult := registry.client.Do( setCtx, @@ -553,7 +551,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) { err = registry.cleanupDeadNodes() assert.NoError(t, err) - checkCtx, checkCancel := context.WithTimeout(context.Background(), registry.timeout) + checkCtx, checkCancel := context.WithTimeout(t.Context(), registry.timeout) defer checkCancel() infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix) @@ -566,7 +564,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) { node2.ID.String(), nodeActiveBackupsSuffix, ) - counterCtx, counterCancel := context.WithTimeout(context.Background(), registry.timeout) + counterCtx, counterCancel := context.WithTimeout(t.Context(), registry.timeout) defer counterCancel() counterResult := registry.client.Do( counterCtx, @@ -575,7 +573,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) { assert.Error(t, counterResult.Error()) activeInfoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node1.ID.String(), nodeInfoKeySuffix) - activeCtx, activeCancel := context.WithTimeout(context.Background(), registry.timeout) + activeCtx, activeCancel := context.WithTimeout(t.Context(), registry.timeout) defer activeCancel() activeResult := registry.client.Do( activeCtx, @@ -601,8 +599,6 @@ func createTestRegistry() *BackupNodesRegistry { timeout: cache_utils.DefaultCacheTimeout, pubsubBackups: cache_utils.NewPubSubManager(), pubsubCompletions: cache_utils.NewPubSubManager(), - runOnce: sync.Once{}, - hasRun: atomic.Bool{}, } } @@ -732,7 +728,7 @@ func Test_SubscribeNodeForBackupsAssignment_HandlesInvalidJson(t *testing.T) { time.Sleep(100 * time.Millisecond) - ctx := context.Background() + ctx := t.Context() err = registry.pubsubBackups.Publish(ctx, "backup:submit", "invalid json") assert.NoError(t, err) @@ -978,7 +974,7 @@ func Test_SubscribeForBackupsCompletions_HandlesInvalidJson(t *testing.T) { time.Sleep(100 * time.Millisecond) - ctx := context.Background() + ctx := t.Context() err = registry.pubsubCompletions.Publish(ctx, "backup:completion", "invalid json") assert.NoError(t, err) @@ -1093,7 +1089,7 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) { receivedAll2 := []uuid.UUID{} receivedAll3 := []uuid.UUID{} - for i := 0; i < 3; i++ { + for range 3 { select { case received := <-receivedBackups1: receivedAll1 = append(receivedAll1, received) @@ -1102,7 +1098,7 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) { } } - for i := 0; i < 3; i++ { + for range 3 { select { case received := <-receivedBackups2: receivedAll2 = append(receivedAll2, received) @@ -1111,7 +1107,7 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) { } } - for i := 0; i < 3; i++ { + for range 3 { select { case received := <-receivedBackups3: receivedAll3 = append(receivedAll3, received) diff --git a/backend/internal/features/backups/backups/backuping/scheduler.go b/backend/internal/features/backups/backups/backuping/scheduler.go index cb8624a..732cd57 100644 --- a/backend/internal/features/backups/backups/backuping/scheduler.go +++ b/backend/internal/features/backups/backups/backuping/scheduler.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "log/slog" - "sync" "sync/atomic" "time" @@ -37,68 +36,61 @@ type BackupsScheduler struct { backupToNodeRelations map[uuid.UUID]BackupToNodeRelation backuperNode *BackuperNode - runOnce sync.Once - hasRun atomic.Bool + hasRun atomic.Bool } func (s *BackupsScheduler) Run(ctx context.Context) { - wasAlreadyRun := s.hasRun.Load() - - s.runOnce.Do(func() { - s.hasRun.Store(true) - - s.lastBackupTime = time.Now().UTC() - - if config.GetEnv().IsManyNodesMode { - // wait other nodes to start - time.Sleep(schedulerStartupDelay) - } - - if err := s.failBackupsInProgress(); err != nil { - s.logger.Error("Failed to fail backups in progress", "error", err) - panic(err) - } - - err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted) - if err != nil { - s.logger.Error("Failed to subscribe to backup completions", "error", err) - panic(err) - } - - defer func() { - if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil { - s.logger.Error("Failed to unsubscribe from backup completions", "error", err) - } - }() - - if ctx.Err() != nil { - return - } - - ticker := time.NewTicker(schedulerTickerInterval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if err := s.checkDeadNodesAndFailBackups(); err != nil { - s.logger.Error("Failed to check dead nodes and fail backups", "error", err) - } - - if err := s.runPendingBackups(); err != nil { - s.logger.Error("Failed to run pending backups", "error", err) - } - - s.lastBackupTime = time.Now().UTC() - } - } - }) - - if wasAlreadyRun { + if s.hasRun.Swap(true) { panic(fmt.Sprintf("%T.Run() called multiple times", s)) } + + s.lastBackupTime = time.Now().UTC() + + if config.GetEnv().IsManyNodesMode { + // wait other nodes to start + time.Sleep(schedulerStartupDelay) + } + + if err := s.failBackupsInProgress(); err != nil { + s.logger.Error("Failed to fail backups in progress", "error", err) + panic(err) + } + + err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted) + if err != nil { + s.logger.Error("Failed to subscribe to backup completions", "error", err) + panic(err) + } + + defer func() { + if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil { + s.logger.Error("Failed to unsubscribe from backup completions", "error", err) + } + }() + + if ctx.Err() != nil { + return + } + + ticker := time.NewTicker(schedulerTickerInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := s.checkDeadNodesAndFailBackups(); err != nil { + s.logger.Error("Failed to check dead nodes and fail backups", "error", err) + } + + if err := s.runPendingBackups(); err != nil { + s.logger.Error("Failed to run pending backups", "error", err) + } + + s.lastBackupTime = time.Now().UTC() + } + } } func (s *BackupsScheduler) IsSchedulerRunning() bool { diff --git a/backend/internal/features/backups/backups/backuping/testing.go b/backend/internal/features/backups/backups/backuping/testing.go index 0d2fec7..3ce2e2e 100644 --- a/backend/internal/features/backups/backups/backuping/testing.go +++ b/backend/internal/features/backups/backups/backuping/testing.go @@ -3,7 +3,6 @@ package backuping import ( "context" "fmt" - "sync" "sync/atomic" "testing" "time" @@ -44,7 +43,6 @@ func CreateTestBackupCleaner(billingService BillingService) *BackupCleaner { encryption.GetFieldEncryptor(), logger.GetLogger(), []backups_core.BackupRemoveListener{}, - sync.Once{}, atomic.Bool{}, } } @@ -64,7 +62,6 @@ func CreateTestBackuperNode() *BackuperNode { usecases.GetCreateBackupUsecase(), uuid.New(), time.Time{}, - sync.Once{}, atomic.Bool{}, } } @@ -84,7 +81,6 @@ func CreateTestBackuperNodeWithUseCase(useCase backups_core.CreateBackupUsecase) useCase, uuid.New(), time.Time{}, - sync.Once{}, atomic.Bool{}, } } @@ -101,7 +97,6 @@ func CreateTestScheduler(billingService BillingService) *BackupsScheduler { logger.GetLogger(), make(map[uuid.UUID]BackupToNodeRelation), CreateTestBackuperNode(), - sync.Once{}, atomic.Bool{}, } } diff --git a/backend/internal/features/backups/backups/download/background.go b/backend/internal/features/backups/backups/download/background.go index 54e664a..0f95240 100644 --- a/backend/internal/features/backups/backups/download/background.go +++ b/backend/internal/features/backups/backups/download/background.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "log/slog" - "sync" "sync/atomic" "time" ) @@ -13,38 +12,31 @@ type DownloadTokenBackgroundService struct { downloadTokenService *DownloadTokenService logger *slog.Logger - runOnce sync.Once - hasRun atomic.Bool + hasRun atomic.Bool } func (s *DownloadTokenBackgroundService) Run(ctx context.Context) { - wasAlreadyRun := s.hasRun.Load() - - s.runOnce.Do(func() { - s.hasRun.Store(true) - - s.logger.Info("Starting download token cleanup background service") - - if ctx.Err() != nil { - return - } - - ticker := time.NewTicker(1 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if err := s.downloadTokenService.CleanExpiredTokens(); err != nil { - s.logger.Error("Failed to clean expired download tokens", "error", err) - } - } - } - }) - - if wasAlreadyRun { + if s.hasRun.Swap(true) { panic(fmt.Sprintf("%T.Run() called multiple times", s)) } + + s.logger.Info("Starting download token cleanup background service") + + if ctx.Err() != nil { + return + } + + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := s.downloadTokenService.CleanExpiredTokens(); err != nil { + s.logger.Error("Failed to clean expired download tokens", "error", err) + } + } + } } diff --git a/backend/internal/features/backups/backups/download/di.go b/backend/internal/features/backups/backups/download/di.go index 80c72b6..4161df9 100644 --- a/backend/internal/features/backups/backups/download/di.go +++ b/backend/internal/features/backups/backups/download/di.go @@ -1,9 +1,6 @@ package backups_download import ( - "sync" - "sync/atomic" - "databasus-backend/internal/config" cache_utils "databasus-backend/internal/util/cache" "databasus-backend/internal/util/logger" @@ -37,8 +34,6 @@ func init() { downloadTokenBackgroundService = &DownloadTokenBackgroundService{ downloadTokenService: downloadTokenService, logger: logger.GetLogger(), - runOnce: sync.Once{}, - hasRun: atomic.Bool{}, } } diff --git a/backend/internal/features/backups/backups/services/di.go b/backend/internal/features/backups/backups/services/di.go index a0b5cfc..3094e52 100644 --- a/backend/internal/features/backups/backups/services/di.go +++ b/backend/internal/features/backups/backups/services/di.go @@ -2,7 +2,6 @@ package backups_services import ( "sync" - "sync/atomic" audit_logs "databasus-backend/internal/features/audit_logs" "databasus-backend/internal/features/backups/backups/backuping" @@ -59,26 +58,11 @@ func GetWalService() *PostgreWalBackupService { return walService } -var ( - setupOnce sync.Once - isSetup atomic.Bool -) +var SetupDependencies = sync.OnceFunc(func() { + backups_config. + GetBackupConfigService(). + SetDatabaseStorageChangeListener(backupService) -func SetupDependencies() { - wasAlreadySetup := isSetup.Load() - - setupOnce.Do(func() { - backups_config. - GetBackupConfigService(). - SetDatabaseStorageChangeListener(backupService) - - databases.GetDatabaseService().AddDbRemoveListener(backupService) - databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService()) - - isSetup.Store(true) - }) - - if wasAlreadySetup { - logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") - } -} + databases.GetDatabaseService().AddDbRemoveListener(backupService) + databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService()) +}) diff --git a/backend/internal/features/backups/config/di.go b/backend/internal/features/backups/config/di.go index 0626619..bbf1a12 100644 --- a/backend/internal/features/backups/config/di.go +++ b/backend/internal/features/backups/config/di.go @@ -2,13 +2,11 @@ package backups_config import ( "sync" - "sync/atomic" "databasus-backend/internal/features/databases" "databasus-backend/internal/features/notifiers" "databasus-backend/internal/features/storages" workspaces_services "databasus-backend/internal/features/workspaces/services" - "databasus-backend/internal/util/logger" ) var ( @@ -35,21 +33,6 @@ func GetBackupConfigService() *BackupConfigService { return backupConfigService } -var ( - setupOnce sync.Once - isSetup atomic.Bool -) - -func SetupDependencies() { - wasAlreadySetup := isSetup.Load() - - setupOnce.Do(func() { - storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService) - - isSetup.Store(true) - }) - - if wasAlreadySetup { - logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") - } -} +var SetupDependencies = sync.OnceFunc(func() { + storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService) +}) diff --git a/backend/internal/features/backups/config/dto.go b/backend/internal/features/backups/config/dto.go index 46559cf..4a9d213 100644 --- a/backend/internal/features/backups/config/dto.go +++ b/backend/internal/features/backups/config/dto.go @@ -7,5 +7,5 @@ type TransferDatabaseRequest struct { TargetStorageID *uuid.UUID `json:"targetStorageId,omitempty"` IsTransferWithStorage bool `json:"isTransferWithStorage,omitempty"` IsTransferWithNotifiers bool `json:"isTransferWithNotifiers,omitempty"` - TargetNotifierIDs []uuid.UUID `json:"targetNotifierIds,omitempty"` + TargetNotifierIDs []uuid.UUID `json:"targetNotifierIds,omitzero"` } diff --git a/backend/internal/features/backups/config/model.go b/backend/internal/features/backups/config/model.go index bd2bc60..51f483b 100644 --- a/backend/internal/features/backups/config/model.go +++ b/backend/internal/features/backups/config/model.go @@ -28,8 +28,8 @@ type BackupConfig struct { RetentionGfsMonths int `json:"retentionGfsMonths" gorm:"column:retention_gfs_months;type:int;not null;default:0"` RetentionGfsYears int `json:"retentionGfsYears" gorm:"column:retention_gfs_years;type:int;not null;default:0"` - BackupIntervalID uuid.UUID `json:"backupIntervalId" gorm:"column:backup_interval_id;type:uuid;not null"` - BackupInterval *intervals.Interval `json:"backupInterval,omitempty" gorm:"foreignKey:BackupIntervalID"` + BackupIntervalID uuid.UUID `json:"backupIntervalId" gorm:"column:backup_interval_id;type:uuid;not null"` + BackupInterval *intervals.Interval `json:"backupInterval,omitzero" gorm:"foreignKey:BackupIntervalID"` Storage *storages.Storage `json:"storage" gorm:"foreignKey:StorageID"` StorageID *uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;"` diff --git a/backend/internal/features/billing/controller_test.go b/backend/internal/features/billing/controller_test.go index 2257a1f..4571c45 100644 --- a/backend/internal/features/billing/controller_test.go +++ b/backend/internal/features/billing/controller_test.go @@ -584,7 +584,7 @@ func Test_GetInvoices_WithPagination_ReturnsCorrectPage(t *testing.T) { sub := activateSubscriptionViaWebhook(t, router, owner.Token, db.ID, 50) - for i := 0; i < 3; i++ { + for i := range 3 { invoiceID := fmt.Sprintf("inv-pag-%d-%s", i, uuid.New().String()[:8]) evt := makePaymentWebhookEvent(invoiceID, 50, int64(500+i*100)) err := billingService.RecordPaymentSuccess(log, sub, evt) diff --git a/backend/internal/features/billing/di.go b/backend/internal/features/billing/di.go index 6418291..7d40126 100644 --- a/backend/internal/features/billing/di.go +++ b/backend/internal/features/billing/di.go @@ -7,7 +7,6 @@ import ( billing_repositories "databasus-backend/internal/features/billing/repositories" "databasus-backend/internal/features/databases" workspaces_services "databasus-backend/internal/features/workspaces/services" - "databasus-backend/internal/util/logger" ) var ( @@ -18,13 +17,9 @@ var ( nil, // billing provider will be set later to avoid circular dependency workspaces_services.GetWorkspaceService(), *databases.GetDatabaseService(), - sync.Once{}, atomic.Bool{}, } billingController = &BillingController{billingService} - - setupOnce sync.Once - isSetup atomic.Bool ) func GetBillingService() *BillingService { @@ -35,15 +30,6 @@ func GetBillingController() *BillingController { return billingController } -func SetupDependencies() { - wasAlreadySetup := isSetup.Load() - - setupOnce.Do(func() { - databases.GetDatabaseService().AddDbCreationListener(billingService) - isSetup.Store(true) - }) - - if wasAlreadySetup { - logger.GetLogger().Warn("billing.SetupDependencies called multiple times, ignoring subsequent call") - } -} +var SetupDependencies = sync.OnceFunc(func() { + databases.GetDatabaseService().AddDbCreationListener(billingService) +}) diff --git a/backend/internal/features/billing/models/invoice.go b/backend/internal/features/billing/models/invoice.go index dc03d0d..03667dc 100644 --- a/backend/internal/features/billing/models/invoice.go +++ b/backend/internal/features/billing/models/invoice.go @@ -15,7 +15,7 @@ type Invoice struct { PeriodStart time.Time `json:"periodStart" gorm:"column:period_start;type:timestamptz;not null"` PeriodEnd time.Time `json:"periodEnd" gorm:"column:period_end;type:timestamptz;not null"` Status InvoiceStatus `json:"status" gorm:"column:status;type:text;not null"` - PaidAt *time.Time `json:"paidAt,omitempty" gorm:"column:paid_at;type:timestamptz"` + PaidAt *time.Time `json:"paidAt,omitzero" gorm:"column:paid_at;type:timestamptz"` CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"` } diff --git a/backend/internal/features/billing/models/subscription.go b/backend/internal/features/billing/models/subscription.go index 707fc36..ace1cab 100644 --- a/backend/internal/features/billing/models/subscription.go +++ b/backend/internal/features/billing/models/subscription.go @@ -16,11 +16,11 @@ type Subscription struct { StorageGB int `json:"storageGb" gorm:"column:storage_gb;type:int;not null"` PendingStorageGB *int `json:"pendingStorageGb,omitempty" gorm:"column:pending_storage_gb;type:int"` - CurrentPeriodStart time.Time `json:"currentPeriodStart" gorm:"column:current_period_start;type:timestamptz;not null"` - CurrentPeriodEnd time.Time `json:"currentPeriodEnd" gorm:"column:current_period_end;type:timestamptz;not null"` - CanceledAt *time.Time `json:"canceledAt,omitempty" gorm:"column:canceled_at;type:timestamptz"` + CurrentPeriodStart time.Time `json:"currentPeriodStart" gorm:"column:current_period_start;type:timestamptz;not null"` + CurrentPeriodEnd time.Time `json:"currentPeriodEnd" gorm:"column:current_period_end;type:timestamptz;not null"` + CanceledAt *time.Time `json:"canceledAt,omitzero" gorm:"column:canceled_at;type:timestamptz"` - DataRetentionGracePeriodUntil *time.Time `json:"dataRetentionGracePeriodUntil,omitempty" gorm:"column:data_retention_grace_period_until;type:timestamptz"` + DataRetentionGracePeriodUntil *time.Time `json:"dataRetentionGracePeriodUntil,omitzero" gorm:"column:data_retention_grace_period_until;type:timestamptz"` ProviderName *string `json:"providerName,omitempty" gorm:"column:provider_name;type:text"` ProviderSubID *string `json:"providerSubId,omitempty" gorm:"column:provider_sub_id;type:text"` diff --git a/backend/internal/features/billing/paddle/di.go b/backend/internal/features/billing/paddle/di.go index e7793fe..a53b40c 100644 --- a/backend/internal/features/billing/paddle/di.go +++ b/backend/internal/features/billing/paddle/di.go @@ -13,46 +13,46 @@ import ( var ( paddleBillingService *PaddleBillingService paddleBillingController *PaddleBillingController - initOnce sync.Once ) +var initPaddle = sync.OnceFunc(func() { + if config.GetEnv().IsPaddleSandbox { + paddleClient, err := paddle.NewSandbox(config.GetEnv().PaddleApiKey) + if err != nil { + return + } + + paddleBillingService = &PaddleBillingService{ + paddleClient, + paddle.NewWebhookVerifier(config.GetEnv().PaddleWebhookSecret), + config.GetEnv().PaddlePriceID, + billing_webhooks.WebhookRepository{}, + billing.GetBillingService(), + } + } else { + paddleClient, err := paddle.New(config.GetEnv().PaddleApiKey) + if err != nil { + return + } + + paddleBillingService = &PaddleBillingService{ + paddleClient, + paddle.NewWebhookVerifier(config.GetEnv().PaddleWebhookSecret), + config.GetEnv().PaddlePriceID, + billing_webhooks.WebhookRepository{}, + billing.GetBillingService(), + } + } + + paddleBillingController = &PaddleBillingController{paddleBillingService} +}) + func GetPaddleBillingService() *PaddleBillingService { if !config.GetEnv().IsCloud { return nil } - initOnce.Do(func() { - if config.GetEnv().IsPaddleSandbox { - paddleClient, err := paddle.NewSandbox(config.GetEnv().PaddleApiKey) - if err != nil { - return - } - - paddleBillingService = &PaddleBillingService{ - paddleClient, - paddle.NewWebhookVerifier(config.GetEnv().PaddleWebhookSecret), - config.GetEnv().PaddlePriceID, - billing_webhooks.WebhookRepository{}, - billing.GetBillingService(), - } - } else { - paddleClient, err := paddle.New(config.GetEnv().PaddleApiKey) - if err != nil { - return - } - - paddleBillingService = &PaddleBillingService{ - paddleClient, - paddle.NewWebhookVerifier(config.GetEnv().PaddleWebhookSecret), - config.GetEnv().PaddlePriceID, - billing_webhooks.WebhookRepository{}, - billing.GetBillingService(), - } - } - - paddleBillingController = &PaddleBillingController{paddleBillingService} - }) - + initPaddle() return paddleBillingService } diff --git a/backend/internal/features/billing/service.go b/backend/internal/features/billing/service.go index 6b94b37..de4016f 100644 --- a/backend/internal/features/billing/service.go +++ b/backend/internal/features/billing/service.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "log/slog" - "sync" "sync/atomic" "time" @@ -35,57 +34,50 @@ type BillingService struct { workspaceService *workspaces_services.WorkspaceService databaseService databases.DatabaseService - runOnce sync.Once - hasRun atomic.Bool + hasRun atomic.Bool } func (s *BillingService) Run(ctx context.Context, logger slog.Logger) { - wasAlreadyRun := s.hasRun.Load() + if s.hasRun.Swap(true) { + panic(fmt.Sprintf("%T.Run() called multiple times", s)) + } - s.runOnce.Do(func() { - s.hasRun.Store(true) + ticker := time.NewTicker(billingTickerInterval) + defer ticker.Stop() - ticker := time.NewTicker(billingTickerInterval) - defer ticker.Stop() + // Run immediately on start + expiredSubsLog := logger.With("task_name", "process_expired_subscriptions") + if err := s.processExpiredSubscriptions(expiredSubsLog); err != nil { + expiredSubsLog.Error("failed to process expired subscriptions", "error", err) + } - // Run immediately on start - expiredSubsLog := logger.With("task_name", "process_expired_subscriptions") - if err := s.processExpiredSubscriptions(expiredSubsLog); err != nil { - expiredSubsLog.Error("failed to process expired subscriptions", "error", err) - } + expiredTrialsLog := logger.With("task_name", "process_expired_trials") + if err := s.processExpiredTrials(expiredTrialsLog); err != nil { + expiredTrialsLog.Error("failed to process expired trials", "error", err) + } - expiredTrialsLog := logger.With("task_name", "process_expired_trials") - if err := s.processExpiredTrials(expiredTrialsLog); err != nil { - expiredTrialsLog.Error("failed to process expired trials", "error", err) - } + reconcileSubsLog := logger.With("task_name", "reconcile_subscriptions") + if err := s.reconcileSubscriptions(reconcileSubsLog); err != nil { + reconcileSubsLog.Error("failed to reconcile subscriptions", "error", err) + } - reconcileSubsLog := logger.With("task_name", "reconcile_subscriptions") - if err := s.reconcileSubscriptions(reconcileSubsLog); err != nil { - reconcileSubsLog.Error("failed to reconcile subscriptions", "error", err) - } + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := s.processExpiredSubscriptions(expiredSubsLog); err != nil { + expiredSubsLog.Error("failed to process expired subscriptions", "error", err) + } - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if err := s.processExpiredSubscriptions(expiredSubsLog); err != nil { - expiredSubsLog.Error("failed to process expired subscriptions", "error", err) - } + if err := s.processExpiredTrials(expiredTrialsLog); err != nil { + expiredTrialsLog.Error("failed to process expired trials", "error", err) + } - if err := s.processExpiredTrials(expiredTrialsLog); err != nil { - expiredTrialsLog.Error("failed to process expired trials", "error", err) - } - - if err := s.reconcileSubscriptions(reconcileSubsLog); err != nil { - reconcileSubsLog.Error("failed to reconcile subscriptions", "error", err) - } + if err := s.reconcileSubscriptions(reconcileSubsLog); err != nil { + reconcileSubsLog.Error("failed to reconcile subscriptions", "error", err) } } - }) - - if wasAlreadyRun { - panic(fmt.Sprintf("%T.Run() called multiple times", s)) } } diff --git a/backend/internal/features/databases/databases/mariadb/model_test.go b/backend/internal/features/databases/databases/mariadb/model_test.go index 96f6bc6..e1d699d 100644 --- a/backend/internal/features/databases/databases/mariadb/model_test.go +++ b/backend/internal/features/databases/databases/mariadb/model_test.go @@ -1,7 +1,6 @@ package mariadb import ( - "context" "fmt" "log/slog" "os" @@ -212,7 +211,7 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) { mariadbModel := createMariadbModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() isReadOnly, privileges, err := mariadbModel.IsUserReadOnly(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -241,7 +240,7 @@ func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) { mariadbModel := createMariadbModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -313,7 +312,7 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) { mariadbModel := createMariadbModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -390,7 +389,7 @@ func Test_ReadOnlyUser_FutureTables_NoSelectPermission(t *testing.T) { mariadbModel := createMariadbModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -466,7 +465,7 @@ func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) { } logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -511,7 +510,7 @@ func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) { mariadbModel := createMariadbModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.NoError(t, err) diff --git a/backend/internal/features/databases/databases/mongodb/model_test.go b/backend/internal/features/databases/databases/mongodb/model_test.go index 53c5734..64bb809 100644 --- a/backend/internal/features/databases/databases/mongodb/model_test.go +++ b/backend/internal/features/databases/databases/mongodb/model_test.go @@ -42,9 +42,9 @@ func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) { t.Parallel() container := connectToMongodbContainer(t, tc.port, tc.version) - defer container.Client.Disconnect(context.Background()) + defer container.Client.Disconnect(t.Context()) - ctx := context.Background() + ctx := t.Context() db := container.Client.Database(container.Database) _ = db.Collection("permission_test").Drop(ctx) @@ -108,9 +108,9 @@ func Test_TestConnection_SufficientPermissions_Success(t *testing.T) { t.Parallel() container := connectToMongodbContainer(t, tc.port, tc.version) - defer container.Client.Disconnect(context.Background()) + defer container.Client.Disconnect(t.Context()) - ctx := context.Background() + ctx := t.Context() db := container.Client.Database(container.Database) _ = db.Collection("backup_test").Drop(ctx) @@ -178,11 +178,11 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) { t.Parallel() container := connectToMongodbContainer(t, tc.port, tc.version) - defer container.Client.Disconnect(context.Background()) + defer container.Client.Disconnect(t.Context()) mongodbModel := createMongodbModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() isReadOnly, roles, err := mongodbModel.IsUserReadOnly(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -195,9 +195,9 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) { func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) { env := config.GetEnv() container := connectToMongodbContainer(t, env.TestMongodb70Port, tools.MongodbVersion7) - defer container.Client.Disconnect(context.Background()) + defer container.Client.Disconnect(t.Context()) - ctx := context.Background() + ctx := t.Context() db := container.Client.Database(container.Database) _ = db.Collection("readonly_check_test").Drop(ctx) @@ -251,15 +251,15 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) { t.Parallel() container := connectToMongodbContainer(t, tc.port, tc.version) - defer container.Client.Disconnect(context.Background()) + defer container.Client.Disconnect(t.Context()) - ctx := context.Background() + ctx := t.Context() db := container.Client.Database(container.Database) _ = db.Collection("readonly_test").Drop(ctx) _ = db.Collection("hack_collection").Drop(ctx) - _, err := db.Collection("readonly_test").InsertMany(ctx, []interface{}{ + _, err := db.Collection("readonly_test").InsertMany(ctx, []any{ bson.M{"data": "test1"}, bson.M{"data": "test2"}, }) @@ -317,9 +317,9 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) { func Test_ReadOnlyUser_FutureCollections_CanSelect(t *testing.T) { env := config.GetEnv() container := connectToMongodbContainer(t, env.TestMongodb70Port, tools.MongodbVersion7) - defer container.Client.Disconnect(context.Background()) + defer container.Client.Disconnect(t.Context()) - ctx := context.Background() + ctx := t.Context() db := container.Client.Database(container.Database) mongodbModel := createMongodbModel(container) @@ -348,9 +348,9 @@ func Test_ReadOnlyUser_FutureCollections_CanSelect(t *testing.T) { func Test_ReadOnlyUser_CannotDropOrModifyCollections(t *testing.T) { env := config.GetEnv() container := connectToMongodbContainer(t, env.TestMongodb70Port, tools.MongodbVersion7) - defer container.Client.Disconnect(context.Background()) + defer container.Client.Disconnect(t.Context()) - ctx := context.Background() + ctx := t.Context() db := container.Client.Database(container.Database) _ = db.Collection("drop_test").Drop(ctx) @@ -420,7 +420,7 @@ func connectToMongodbContainer( authDatabase, ) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) defer cancel() clientOptions := options.Client().ApplyURI(uri) @@ -473,7 +473,7 @@ func connectWithCredentials( container.Database, container.AuthDatabase, ) - ctx := context.Background() + ctx := t.Context() clientOptions := options.Client().ApplyURI(uri) client, err := mongo.Connect(ctx, clientOptions) assert.NoError(t, err) diff --git a/backend/internal/features/databases/databases/mysql/model_test.go b/backend/internal/features/databases/databases/mysql/model_test.go index 68127a6..0d0ee9e 100644 --- a/backend/internal/features/databases/databases/mysql/model_test.go +++ b/backend/internal/features/databases/databases/mysql/model_test.go @@ -1,7 +1,6 @@ package mysql import ( - "context" "fmt" "log/slog" "os" @@ -231,7 +230,7 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) { mysqlModel := createMysqlModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() isReadOnly, privileges, err := mysqlModel.IsUserReadOnly(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -260,7 +259,7 @@ func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) { mysqlModel := createMysqlModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -326,7 +325,7 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) { mysqlModel := createMysqlModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -400,7 +399,7 @@ func Test_ReadOnlyUser_FutureTables_NoSelectPermission(t *testing.T) { mysqlModel := createMysqlModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -477,7 +476,7 @@ func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) { } logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -523,7 +522,7 @@ func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) { mysqlModel := createMysqlModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.NoError(t, err) diff --git a/backend/internal/features/databases/databases/postgresql/model_test.go b/backend/internal/features/databases/databases/postgresql/model_test.go index 0eb4c40..2476cf7 100644 --- a/backend/internal/features/databases/databases/postgresql/model_test.go +++ b/backend/internal/features/databases/databases/postgresql/model_test.go @@ -1,7 +1,6 @@ package postgresql import ( - "context" "fmt" "log/slog" "os" @@ -267,7 +266,7 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) { pgModel := createPostgresModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() isReadOnly, privileges, err := pgModel.IsUserReadOnly(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -294,7 +293,7 @@ func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) { pgModel := createPostgresModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -359,7 +358,7 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) { pgModel := createPostgresModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -438,7 +437,7 @@ func Test_ReadOnlyUser_FutureTables_HaveSelectPermission(t *testing.T) { pgModel := createPostgresModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -491,7 +490,7 @@ func Test_ReadOnlyUser_MultipleSchemas_AllAccessible(t *testing.T) { pgModel := createPostgresModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -566,7 +565,7 @@ func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) { } logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -653,7 +652,7 @@ func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) { } logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() connectionUsername, newPassword, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -743,7 +742,7 @@ func Test_CreateReadOnlyUser_WithPublicSchema_Success(t *testing.T) { pgModel := createPostgresModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.NoError(t, err) @@ -851,7 +850,7 @@ func Test_CreateReadOnlyUser_WithoutPublicSchema_Success(t *testing.T) { pgModel := createPostgresModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.NoError(t, err, "CreateReadOnlyUser should succeed without public schema") @@ -1018,7 +1017,7 @@ func Test_CreateReadOnlyUser_PublicSchemaExistsButNoPermissions_ReturnsError(t * } logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New()) assert.Error( @@ -1435,7 +1434,7 @@ func Test_CreateReadOnlyUser_TablesCreatedByDifferentUser_ReadOnlyUserCanRead(t // At this point, user_creator already owns objects, so ALTER DEFAULT PRIVILEGES FOR ROLE should apply pgModel := createPostgresModel(container) logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser( ctx, @@ -1602,7 +1601,7 @@ func Test_CreateReadOnlyUser_WithIncludeSchemas_OnlyGrantsAccessToSpecifiedSchem pgModel.IncludeSchemas = []string{"public", "included_schema"} logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - ctx := context.Background() + ctx := t.Context() readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser( ctx, diff --git a/backend/internal/features/databases/di.go b/backend/internal/features/databases/di.go index 870c3a6..4dec3b4 100644 --- a/backend/internal/features/databases/di.go +++ b/backend/internal/features/databases/di.go @@ -2,7 +2,6 @@ package databases import ( "sync" - "sync/atomic" audit_logs "databasus-backend/internal/features/audit_logs" "databasus-backend/internal/features/notifiers" @@ -40,22 +39,7 @@ func GetDatabaseController() *DatabaseController { return databaseController } -var ( - setupOnce sync.Once - isSetup atomic.Bool -) - -func SetupDependencies() { - wasAlreadySetup := isSetup.Load() - - setupOnce.Do(func() { - workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService) - notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService) - - isSetup.Store(true) - }) - - if wasAlreadySetup { - logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") - } -} +var SetupDependencies = sync.OnceFunc(func() { + workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService) + notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService) +}) diff --git a/backend/internal/features/databases/model.go b/backend/internal/features/databases/model.go index 0e8f012..f13e7a7 100644 --- a/backend/internal/features/databases/model.go +++ b/backend/internal/features/databases/model.go @@ -25,16 +25,16 @@ type Database struct { Name string `json:"name" gorm:"column:name;type:text;not null"` Type DatabaseType `json:"type" gorm:"column:type;type:text;not null"` - Postgresql *postgresql.PostgresqlDatabase `json:"postgresql,omitempty" gorm:"foreignKey:DatabaseID"` - Mysql *mysql.MysqlDatabase `json:"mysql,omitempty" gorm:"foreignKey:DatabaseID"` - Mariadb *mariadb.MariadbDatabase `json:"mariadb,omitempty" gorm:"foreignKey:DatabaseID"` - Mongodb *mongodb.MongodbDatabase `json:"mongodb,omitempty" gorm:"foreignKey:DatabaseID"` + Postgresql *postgresql.PostgresqlDatabase `json:"postgresql,omitzero" gorm:"foreignKey:DatabaseID"` + Mysql *mysql.MysqlDatabase `json:"mysql,omitzero" gorm:"foreignKey:DatabaseID"` + Mariadb *mariadb.MariadbDatabase `json:"mariadb,omitzero" gorm:"foreignKey:DatabaseID"` + Mongodb *mongodb.MongodbDatabase `json:"mongodb,omitzero" gorm:"foreignKey:DatabaseID"` Notifiers []notifiers.Notifier `json:"notifiers" gorm:"many2many:database_notifiers;"` // these fields are not reliable, but // they are used for pretty UI - LastBackupTime *time.Time `json:"lastBackupTime,omitempty" gorm:"column:last_backup_time;type:timestamp with time zone"` + LastBackupTime *time.Time `json:"lastBackupTime,omitzero" gorm:"column:last_backup_time;type:timestamp with time zone"` LastBackupErrorMessage *string `json:"lastBackupErrorMessage,omitempty" gorm:"column:last_backup_error_message;type:text"` HealthStatus *HealthStatus `json:"healthStatus" gorm:"column:health_status;type:text;not null"` diff --git a/backend/internal/features/healthcheck/attempt/background_service.go b/backend/internal/features/healthcheck/attempt/background_service.go index dda0d10..c809ba5 100644 --- a/backend/internal/features/healthcheck/attempt/background_service.go +++ b/backend/internal/features/healthcheck/attempt/background_service.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "log/slog" - "sync" "sync/atomic" "time" @@ -16,34 +15,28 @@ type HealthcheckAttemptBackgroundService struct { checkDatabaseHealthUseCase *CheckDatabaseHealthUseCase logger *slog.Logger - runOnce sync.Once - hasRun atomic.Bool + hasRun atomic.Bool } func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) { - wasAlreadyRun := s.hasRun.Load() - - s.runOnce.Do(func() { - s.hasRun.Store(true) - - // first healthcheck immediately - s.checkDatabases() - - ticker := time.NewTicker(time.Minute) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - s.checkDatabases() - } - } - }) - - if wasAlreadyRun { + if s.hasRun.Swap(true) { panic(fmt.Sprintf("%T.Run() called multiple times", s)) } + + // first healthcheck immediately + s.checkDatabases() + + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + s.checkDatabases() + } + } } func (s *HealthcheckAttemptBackgroundService) checkDatabases() { diff --git a/backend/internal/features/healthcheck/attempt/di.go b/backend/internal/features/healthcheck/attempt/di.go index f1d8d15..381874c 100644 --- a/backend/internal/features/healthcheck/attempt/di.go +++ b/backend/internal/features/healthcheck/attempt/di.go @@ -1,9 +1,6 @@ package healthcheck_attempt import ( - "sync" - "sync/atomic" - "databasus-backend/internal/features/databases" healthcheck_config "databasus-backend/internal/features/healthcheck/config" "databasus-backend/internal/features/notifiers" @@ -30,8 +27,6 @@ var healthcheckAttemptBackgroundService = &HealthcheckAttemptBackgroundService{ healthcheckConfigService: healthcheck_config.GetHealthcheckConfigService(), checkDatabaseHealthUseCase: checkDatabaseHealthUseCase, logger: logger.GetLogger(), - runOnce: sync.Once{}, - hasRun: atomic.Bool{}, } var healthcheckAttemptController = &HealthcheckAttemptController{ diff --git a/backend/internal/features/healthcheck/config/di.go b/backend/internal/features/healthcheck/config/di.go index bb2a4a6..b205027 100644 --- a/backend/internal/features/healthcheck/config/di.go +++ b/backend/internal/features/healthcheck/config/di.go @@ -2,7 +2,6 @@ package healthcheck_config import ( "sync" - "sync/atomic" "databasus-backend/internal/features/audit_logs" "databasus-backend/internal/features/databases" @@ -33,23 +32,8 @@ func GetHealthcheckConfigController() *HealthcheckConfigController { return healthcheckConfigController } -var ( - setupOnce sync.Once - isSetup atomic.Bool -) - -func SetupDependencies() { - wasAlreadySetup := isSetup.Load() - - setupOnce.Do(func() { - databases. - GetDatabaseService(). - AddDbCreationListener(healthcheckConfigService) - - isSetup.Store(true) - }) - - if wasAlreadySetup { - logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") - } -} +var SetupDependencies = sync.OnceFunc(func() { + databases. + GetDatabaseService(). + AddDbCreationListener(healthcheckConfigService) +}) diff --git a/backend/internal/features/notifiers/di.go b/backend/internal/features/notifiers/di.go index c86f02a..eefd888 100644 --- a/backend/internal/features/notifiers/di.go +++ b/backend/internal/features/notifiers/di.go @@ -2,7 +2,6 @@ package notifiers import ( "sync" - "sync/atomic" audit_logs "databasus-backend/internal/features/audit_logs" workspaces_services "databasus-backend/internal/features/workspaces/services" @@ -39,21 +38,6 @@ func GetNotifierRepository() *NotifierRepository { return notifierRepository } -var ( - setupOnce sync.Once - isSetup atomic.Bool -) - -func SetupDependencies() { - wasAlreadySetup := isSetup.Load() - - setupOnce.Do(func() { - workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService) - - isSetup.Store(true) - }) - - if wasAlreadySetup { - logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") - } -} +var SetupDependencies = sync.OnceFunc(func() { + workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService) +}) diff --git a/backend/internal/features/notifiers/model.go b/backend/internal/features/notifiers/model.go index 1aa9d3d..bd636d9 100644 --- a/backend/internal/features/notifiers/model.go +++ b/backend/internal/features/notifiers/model.go @@ -23,12 +23,12 @@ type Notifier struct { LastSendError *string `json:"lastSendError" gorm:"column:last_send_error;type:text"` // specific notifier - TelegramNotifier *telegram_notifier.TelegramNotifier `json:"telegramNotifier" gorm:"foreignKey:NotifierID"` - EmailNotifier *email_notifier.EmailNotifier `json:"emailNotifier" gorm:"foreignKey:NotifierID"` - WebhookNotifier *webhook_notifier.WebhookNotifier `json:"webhookNotifier" gorm:"foreignKey:NotifierID"` - SlackNotifier *slack_notifier.SlackNotifier `json:"slackNotifier" gorm:"foreignKey:NotifierID"` - DiscordNotifier *discord_notifier.DiscordNotifier `json:"discordNotifier" gorm:"foreignKey:NotifierID"` - TeamsNotifier *teams_notifier.TeamsNotifier `json:"teamsNotifier,omitempty" gorm:"foreignKey:NotifierID;constraint:OnDelete:CASCADE"` + TelegramNotifier *telegram_notifier.TelegramNotifier `json:"telegramNotifier" gorm:"foreignKey:NotifierID"` + EmailNotifier *email_notifier.EmailNotifier `json:"emailNotifier" gorm:"foreignKey:NotifierID"` + WebhookNotifier *webhook_notifier.WebhookNotifier `json:"webhookNotifier" gorm:"foreignKey:NotifierID"` + SlackNotifier *slack_notifier.SlackNotifier `json:"slackNotifier" gorm:"foreignKey:NotifierID"` + DiscordNotifier *discord_notifier.DiscordNotifier `json:"discordNotifier" gorm:"foreignKey:NotifierID"` + TeamsNotifier *teams_notifier.TeamsNotifier `json:"teamsNotifier,omitzero" gorm:"foreignKey:NotifierID;constraint:OnDelete:CASCADE"` } func (n *Notifier) TableName() string { diff --git a/backend/internal/features/notifiers/models/teams/model.go b/backend/internal/features/notifiers/models/teams/model.go index c75933e..9f925a7 100644 --- a/backend/internal/features/notifiers/models/teams/model.go +++ b/backend/internal/features/notifiers/models/teams/model.go @@ -49,7 +49,7 @@ type cardAttachment struct { type payload struct { Title string `json:"title"` Text string `json:"text"` - Attachments []cardAttachment `json:"attachments,omitempty"` + Attachments []cardAttachment `json:"attachments,omitzero"` } func (n *TeamsNotifier) Send( diff --git a/backend/internal/features/restores/controller_test.go b/backend/internal/features/restores/controller_test.go index 9663dbe..8a90b16 100644 --- a/backend/internal/features/restores/controller_test.go +++ b/backend/internal/features/restores/controller_test.go @@ -462,7 +462,7 @@ func Test_CancelRestore_InProgressRestore_SuccessfullyCancelled(t *testing.T) { }, } - var restoreResponse map[string]interface{} + var restoreResponse map[string]any test_utils.MakePostRequestAndUnmarshal( t, router, diff --git a/backend/internal/features/restores/di.go b/backend/internal/features/restores/di.go index d9e05d1..f2ea1a4 100644 --- a/backend/internal/features/restores/di.go +++ b/backend/internal/features/restores/di.go @@ -2,7 +2,6 @@ package restores import ( "sync" - "sync/atomic" audit_logs "databasus-backend/internal/features/audit_logs" "databasus-backend/internal/features/backups/backups/backuping" @@ -45,22 +44,7 @@ func GetRestoreController() *RestoreController { return restoreController } -var ( - setupOnce sync.Once - isSetup atomic.Bool -) - -func SetupDependencies() { - wasAlreadySetup := isSetup.Load() - - setupOnce.Do(func() { - backups_services.GetBackupService().AddBackupRemoveListener(restoreService) - backuping.GetBackupCleaner().AddBackupRemoveListener(restoreService) - - isSetup.Store(true) - }) - - if wasAlreadySetup { - logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") - } -} +var SetupDependencies = sync.OnceFunc(func() { + backups_services.GetBackupService().AddBackupRemoveListener(restoreService) + backuping.GetBackupCleaner().AddBackupRemoveListener(restoreService) +}) diff --git a/backend/internal/features/restores/restoring/di.go b/backend/internal/features/restores/restoring/di.go index c2000ac..facafb5 100644 --- a/backend/internal/features/restores/restoring/di.go +++ b/backend/internal/features/restores/restoring/di.go @@ -1,7 +1,6 @@ package restoring import ( - "sync" "sync/atomic" "time" @@ -27,8 +26,6 @@ var restoreNodesRegistry = &RestoreNodesRegistry{ timeout: cache_utils.DefaultCacheTimeout, pubsubRestores: cache_utils.NewPubSubManager(), pubsubCompletions: cache_utils.NewPubSubManager(), - runOnce: sync.Once{}, - hasRun: atomic.Bool{}, } var restoreDatabaseCache = cache_utils.NewCacheUtil[RestoreDatabaseCache]( @@ -52,7 +49,6 @@ var restorerNode = &RestorerNode{ restoreDatabaseCache, restoreCancelManager, time.Time{}, - sync.Once{}, atomic.Bool{}, } @@ -68,7 +64,6 @@ var restoresScheduler = &RestoresScheduler{ restorerNode, restoreDatabaseCache, uuid.Nil, - sync.Once{}, atomic.Bool{}, } diff --git a/backend/internal/features/restores/restoring/dto.go b/backend/internal/features/restores/restoring/dto.go index fb17e70..e1a6d55 100644 --- a/backend/internal/features/restores/restoring/dto.go +++ b/backend/internal/features/restores/restoring/dto.go @@ -12,10 +12,10 @@ import ( ) type RestoreDatabaseCache struct { - PostgresqlDatabase *postgresql.PostgresqlDatabase `json:"postgresqlDatabase,omitempty"` - MysqlDatabase *mysql.MysqlDatabase `json:"mysqlDatabase,omitempty"` - MariadbDatabase *mariadb.MariadbDatabase `json:"mariadbDatabase,omitempty"` - MongodbDatabase *mongodb.MongodbDatabase `json:"mongodbDatabase,omitempty"` + PostgresqlDatabase *postgresql.PostgresqlDatabase `json:"postgresqlDatabase,omitzero"` + MysqlDatabase *mysql.MysqlDatabase `json:"mysqlDatabase,omitzero"` + MariadbDatabase *mariadb.MariadbDatabase `json:"mariadbDatabase,omitzero"` + MongodbDatabase *mongodb.MongodbDatabase `json:"mongodbDatabase,omitzero"` } type RestoreToNodeRelation struct { diff --git a/backend/internal/features/restores/restoring/registry.go b/backend/internal/features/restores/restoring/registry.go index 950dfc0..ffc095d 100644 --- a/backend/internal/features/restores/restoring/registry.go +++ b/backend/internal/features/restores/restoring/registry.go @@ -6,7 +6,6 @@ import ( "fmt" "log/slog" "strings" - "sync" "sync/atomic" "time" @@ -50,36 +49,30 @@ type RestoreNodesRegistry struct { pubsubRestores *cache_utils.PubSubManager pubsubCompletions *cache_utils.PubSubManager - runOnce sync.Once - hasRun atomic.Bool + hasRun atomic.Bool } func (r *RestoreNodesRegistry) Run(ctx context.Context) { - wasAlreadyRun := r.hasRun.Load() + if r.hasRun.Swap(true) { + panic(fmt.Sprintf("%T.Run() called multiple times", r)) + } - r.runOnce.Do(func() { - r.hasRun.Store(true) + if err := r.cleanupDeadNodes(); err != nil { + r.logger.Error("Failed to cleanup dead nodes on startup", "error", err) + } - if err := r.cleanupDeadNodes(); err != nil { - r.logger.Error("Failed to cleanup dead nodes on startup", "error", err) - } + ticker := time.NewTicker(cleanupTickerInterval) + defer ticker.Stop() - ticker := time.NewTicker(cleanupTickerInterval) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if err := r.cleanupDeadNodes(); err != nil { - r.logger.Error("Failed to cleanup dead nodes", "error", err) - } + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := r.cleanupDeadNodes(); err != nil { + r.logger.Error("Failed to cleanup dead nodes", "error", err) } } - }) - - if wasAlreadyRun { - panic(fmt.Sprintf("%T.Run() called multiple times", r)) } } diff --git a/backend/internal/features/restores/restoring/registry_test.go b/backend/internal/features/restores/restoring/registry_test.go index 653b491..0c2fd9d 100644 --- a/backend/internal/features/restores/restoring/registry_test.go +++ b/backend/internal/features/restores/restoring/registry_test.go @@ -4,8 +4,6 @@ import ( "context" "encoding/json" "fmt" - "sync" - "sync/atomic" "testing" "time" @@ -322,7 +320,7 @@ func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) { err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node) assert.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), registry.timeout) + ctx, cancel := context.WithTimeout(t.Context(), registry.timeout) defer cancel() invalidKey := nodeInfoKeyPrefix + uuid.New().String() + nodeInfoKeySuffix @@ -331,7 +329,7 @@ func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) { registry.client.B().Set().Key(invalidKey).Value("invalid json data").Build(), ) defer func() { - cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), registry.timeout) + cleanupCtx, cleanupCancel := context.WithTimeout(t.Context(), registry.timeout) defer cleanupCancel() registry.client.Do(cleanupCtx, registry.client.B().Del().Key(invalidKey).Build()) }() @@ -401,7 +399,7 @@ func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) { err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3) assert.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), registry.timeout) + ctx, cancel := context.WithTimeout(t.Context(), registry.timeout) defer cancel() key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix) @@ -419,7 +417,7 @@ func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) { modifiedData, err := json.Marshal(node) assert.NoError(t, err) - setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout) + setCtx, setCancel := context.WithTimeout(t.Context(), registry.timeout) defer setCancel() setResult := registry.client.Do( setCtx, @@ -464,7 +462,7 @@ func Test_GetRestoreNodesStats_ExcludesStaleNodesFromCache(t *testing.T) { err = registry.IncrementRestoresInProgress(node3.ID) assert.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), registry.timeout) + ctx, cancel := context.WithTimeout(t.Context(), registry.timeout) defer cancel() key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix) @@ -482,7 +480,7 @@ func Test_GetRestoreNodesStats_ExcludesStaleNodesFromCache(t *testing.T) { modifiedData, err := json.Marshal(node) assert.NoError(t, err) - setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout) + setCtx, setCancel := context.WithTimeout(t.Context(), registry.timeout) defer setCancel() setResult := registry.client.Do( setCtx, @@ -524,7 +522,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) { err = registry.IncrementRestoresInProgress(node2.ID) assert.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), registry.timeout) + ctx, cancel := context.WithTimeout(t.Context(), registry.timeout) defer cancel() key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix) @@ -542,7 +540,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) { modifiedData, err := json.Marshal(node) assert.NoError(t, err) - setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout) + setCtx, setCancel := context.WithTimeout(t.Context(), registry.timeout) defer setCancel() setResult := registry.client.Do( setCtx, @@ -553,7 +551,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) { err = registry.cleanupDeadNodes() assert.NoError(t, err) - checkCtx, checkCancel := context.WithTimeout(context.Background(), registry.timeout) + checkCtx, checkCancel := context.WithTimeout(t.Context(), registry.timeout) defer checkCancel() infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix) @@ -566,7 +564,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) { node2.ID.String(), nodeActiveRestoresSuffix, ) - counterCtx, counterCancel := context.WithTimeout(context.Background(), registry.timeout) + counterCtx, counterCancel := context.WithTimeout(t.Context(), registry.timeout) defer counterCancel() counterResult := registry.client.Do( counterCtx, @@ -575,7 +573,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) { assert.Error(t, counterResult.Error()) activeInfoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node1.ID.String(), nodeInfoKeySuffix) - activeCtx, activeCancel := context.WithTimeout(context.Background(), registry.timeout) + activeCtx, activeCancel := context.WithTimeout(t.Context(), registry.timeout) defer activeCancel() activeResult := registry.client.Do( activeCtx, @@ -601,8 +599,6 @@ func createTestRegistry() *RestoreNodesRegistry { timeout: cache_utils.DefaultCacheTimeout, pubsubRestores: cache_utils.NewPubSubManager(), pubsubCompletions: cache_utils.NewPubSubManager(), - runOnce: sync.Once{}, - hasRun: atomic.Bool{}, } } @@ -734,7 +730,7 @@ func Test_SubscribeNodeForRestoresAssignment_HandlesInvalidJson(t *testing.T) { time.Sleep(100 * time.Millisecond) - ctx := context.Background() + ctx := t.Context() err = registry.pubsubRestores.Publish(ctx, "restore:submit", "invalid json") assert.NoError(t, err) @@ -980,7 +976,7 @@ func Test_SubscribeForRestoresCompletions_HandlesInvalidJson(t *testing.T) { time.Sleep(100 * time.Millisecond) - ctx := context.Background() + ctx := t.Context() err = registry.pubsubCompletions.Publish(ctx, "restore:completion", "invalid json") assert.NoError(t, err) @@ -1095,7 +1091,7 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) { receivedAll2 := []uuid.UUID{} receivedAll3 := []uuid.UUID{} - for i := 0; i < 3; i++ { + for range 3 { select { case received := <-receivedRestores1: receivedAll1 = append(receivedAll1, received) @@ -1104,7 +1100,7 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) { } } - for i := 0; i < 3; i++ { + for range 3 { select { case received := <-receivedRestores2: receivedAll2 = append(receivedAll2, received) @@ -1113,7 +1109,7 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) { } } - for i := 0; i < 3; i++ { + for range 3 { select { case received := <-receivedRestores3: receivedAll3 = append(receivedAll3, received) diff --git a/backend/internal/features/restores/restoring/restorer.go b/backend/internal/features/restores/restoring/restorer.go index 1ad89cc..8859bed 100644 --- a/backend/internal/features/restores/restoring/restorer.go +++ b/backend/internal/features/restores/restoring/restorer.go @@ -6,7 +6,6 @@ import ( "fmt" "log/slog" "strings" - "sync" "sync/atomic" "time" @@ -45,81 +44,74 @@ type RestorerNode struct { lastHeartbeat time.Time - runOnce sync.Once - hasRun atomic.Bool + hasRun atomic.Bool } func (n *RestorerNode) Run(ctx context.Context) { - wasAlreadyRun := n.hasRun.Load() - - n.runOnce.Do(func() { - n.hasRun.Store(true) - - n.lastHeartbeat = time.Now().UTC() - - throughputMBs := config.GetEnv().NodeNetworkThroughputMBs - - restoreNode := RestoreNode{ - ID: n.nodeID, - ThroughputMBs: throughputMBs, - } - - if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), restoreNode); err != nil { - n.logger.Error("Failed to register node in registry", "error", err) - panic(err) - } - - restoreHandler := func(restoreID uuid.UUID, isCallNotifier bool) { - n.MakeRestore(restoreID) - if err := n.restoreNodesRegistry.PublishRestoreCompletion(n.nodeID, restoreID); err != nil { - n.logger.Error( - "Failed to publish restore completion", - "error", - err, - "restoreID", - restoreID, - ) - } - } - - err := n.restoreNodesRegistry.SubscribeNodeForRestoresAssignment( - n.nodeID, - restoreHandler, - ) - if err != nil { - n.logger.Error("Failed to subscribe to restore assignments", "error", err) - panic(err) - } - defer func() { - if err := n.restoreNodesRegistry.UnsubscribeNodeForRestoresAssignments(); err != nil { - n.logger.Error("Failed to unsubscribe from restore assignments", "error", err) - } - }() - - ticker := time.NewTicker(heartbeatTickerInterval) - defer ticker.Stop() - - n.logger.Info("Restore node started", "nodeID", n.nodeID, "throughput", throughputMBs) - - for { - select { - case <-ctx.Done(): - n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID) - - if err := n.restoreNodesRegistry.UnregisterNodeFromRegistry(restoreNode); err != nil { - n.logger.Error("Failed to unregister node from registry", "error", err) - } - - return - case <-ticker.C: - n.sendHeartbeat(&restoreNode) - } - } - }) - - if wasAlreadyRun { + if n.hasRun.Swap(true) { panic(fmt.Sprintf("%T.Run() called multiple times", n)) } + + n.lastHeartbeat = time.Now().UTC() + + throughputMBs := config.GetEnv().NodeNetworkThroughputMBs + + restoreNode := RestoreNode{ + ID: n.nodeID, + ThroughputMBs: throughputMBs, + } + + if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), restoreNode); err != nil { + n.logger.Error("Failed to register node in registry", "error", err) + panic(err) + } + + restoreHandler := func(restoreID uuid.UUID, isCallNotifier bool) { + n.MakeRestore(restoreID) + if err := n.restoreNodesRegistry.PublishRestoreCompletion(n.nodeID, restoreID); err != nil { + n.logger.Error( + "Failed to publish restore completion", + "error", + err, + "restoreID", + restoreID, + ) + } + } + + err := n.restoreNodesRegistry.SubscribeNodeForRestoresAssignment( + n.nodeID, + restoreHandler, + ) + if err != nil { + n.logger.Error("Failed to subscribe to restore assignments", "error", err) + panic(err) + } + defer func() { + if err := n.restoreNodesRegistry.UnsubscribeNodeForRestoresAssignments(); err != nil { + n.logger.Error("Failed to unsubscribe from restore assignments", "error", err) + } + }() + + ticker := time.NewTicker(heartbeatTickerInterval) + defer ticker.Stop() + + n.logger.Info("Restore node started", "nodeID", n.nodeID, "throughput", throughputMBs) + + for { + select { + case <-ctx.Done(): + n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID) + + if err := n.restoreNodesRegistry.UnregisterNodeFromRegistry(restoreNode); err != nil { + n.logger.Error("Failed to unregister node from registry", "error", err) + } + + return + case <-ticker.C: + n.sendHeartbeat(&restoreNode) + } + } } func (n *RestorerNode) IsRestorerRunning() bool { diff --git a/backend/internal/features/restores/restoring/restorer_test.go b/backend/internal/features/restores/restoring/restorer_test.go index f7881fb..d7ff6ef 100644 --- a/backend/internal/features/restores/restoring/restorer_test.go +++ b/backend/internal/features/restores/restoring/restorer_test.go @@ -144,7 +144,7 @@ func Test_MakeRestore_WhenTaskStarts_CacheDeletedImmediately(t *testing.T) { Port: 5432, Username: "test", Password: "test", - Database: stringPtr("testdb"), + Database: new("testdb"), Version: "16", }, } @@ -162,7 +162,3 @@ func Test_MakeRestore_WhenTaskStarts_CacheDeletedImmediately(t *testing.T) { cachedDBAfter := restoreDatabaseCache.Get(restore.ID.String()) assert.Nil(t, cachedDBAfter, "Cache should be deleted immediately when task starts") } - -func stringPtr(s string) *string { - return &s -} diff --git a/backend/internal/features/restores/restoring/scheduler.go b/backend/internal/features/restores/restoring/scheduler.go index 80c55a8..faeab2d 100644 --- a/backend/internal/features/restores/restoring/scheduler.go +++ b/backend/internal/features/restores/restoring/scheduler.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "log/slog" - "sync" "sync/atomic" "time" @@ -37,64 +36,57 @@ type RestoresScheduler struct { cacheUtil *cache_utils.CacheUtil[RestoreDatabaseCache] completionSubscriptionID uuid.UUID - runOnce sync.Once - hasRun atomic.Bool + hasRun atomic.Bool } func (s *RestoresScheduler) Run(ctx context.Context) { - wasAlreadyRun := s.hasRun.Load() - - s.runOnce.Do(func() { - s.hasRun.Store(true) - - s.lastCheckTime = time.Now().UTC() - - if config.GetEnv().IsManyNodesMode { - // wait other nodes to start - time.Sleep(schedulerStartupDelay) - } - - if err := s.failRestoresInProgress(); err != nil { - s.logger.Error("Failed to fail restores in progress", "error", err) - panic(err) - } - - err := s.restoreNodesRegistry.SubscribeForRestoresCompletions(s.onRestoreCompleted) - if err != nil { - s.logger.Error("Failed to subscribe to restore completions", "error", err) - panic(err) - } - - defer func() { - if err := s.restoreNodesRegistry.UnsubscribeForRestoresCompletions(); err != nil { - s.logger.Error("Failed to unsubscribe from restore completions", "error", err) - } - }() - - if ctx.Err() != nil { - return - } - - ticker := time.NewTicker(schedulerTickerInterval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if err := s.checkDeadNodesAndFailRestores(); err != nil { - s.logger.Error("Failed to check dead nodes and fail restores", "error", err) - } - - s.lastCheckTime = time.Now().UTC() - } - } - }) - - if wasAlreadyRun { + if s.hasRun.Swap(true) { panic(fmt.Sprintf("%T.Run() called multiple times", s)) } + + s.lastCheckTime = time.Now().UTC() + + if config.GetEnv().IsManyNodesMode { + // wait other nodes to start + time.Sleep(schedulerStartupDelay) + } + + if err := s.failRestoresInProgress(); err != nil { + s.logger.Error("Failed to fail restores in progress", "error", err) + panic(err) + } + + err := s.restoreNodesRegistry.SubscribeForRestoresCompletions(s.onRestoreCompleted) + if err != nil { + s.logger.Error("Failed to subscribe to restore completions", "error", err) + panic(err) + } + + defer func() { + if err := s.restoreNodesRegistry.UnsubscribeForRestoresCompletions(); err != nil { + s.logger.Error("Failed to unsubscribe from restore completions", "error", err) + } + }() + + if ctx.Err() != nil { + return + } + + ticker := time.NewTicker(schedulerTickerInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := s.checkDeadNodesAndFailRestores(); err != nil { + s.logger.Error("Failed to check dead nodes and fail restores", "error", err) + } + + s.lastCheckTime = time.Now().UTC() + } + } } func (s *RestoresScheduler) IsSchedulerRunning() bool { diff --git a/backend/internal/features/restores/restoring/scheduler_test.go b/backend/internal/features/restores/restoring/scheduler_test.go index 67f3f11..f34ef0b 100644 --- a/backend/internal/features/restores/restoring/scheduler_test.go +++ b/backend/internal/features/restores/restoring/scheduler_test.go @@ -686,7 +686,7 @@ func Test_StartRestore_CredentialsStoredEncryptedInCache(t *testing.T) { Port: 5432, Username: "testuser", Password: plaintextPassword, - Database: stringPtr("testdb"), + Database: new("testdb"), Version: "16", } diff --git a/backend/internal/features/restores/restoring/testing.go b/backend/internal/features/restores/restoring/testing.go index 30002d7..99e69db 100644 --- a/backend/internal/features/restores/restoring/testing.go +++ b/backend/internal/features/restores/restoring/testing.go @@ -3,7 +3,6 @@ package restoring import ( "context" "fmt" - "sync" "sync/atomic" "testing" "time" @@ -53,7 +52,6 @@ func CreateTestRestorerNode() *RestorerNode { restoreDatabaseCache, tasks_cancellation.GetTaskCancelManager(), time.Time{}, - sync.Once{}, atomic.Bool{}, } } @@ -73,7 +71,6 @@ func CreateTestRestorerNodeWithUsecase(usecase restores_core.RestoreBackupUsecas restoreDatabaseCache, tasks_cancellation.GetTaskCancelManager(), time.Time{}, - sync.Once{}, atomic.Bool{}, } } @@ -91,7 +88,6 @@ func CreateTestRestoresScheduler() *RestoresScheduler { restorerNode, restoreDatabaseCache, uuid.Nil, - sync.Once{}, atomic.Bool{}, } } diff --git a/backend/internal/features/storages/di.go b/backend/internal/features/storages/di.go index e3b6ad9..5259cbd 100644 --- a/backend/internal/features/storages/di.go +++ b/backend/internal/features/storages/di.go @@ -2,12 +2,10 @@ package storages import ( "sync" - "sync/atomic" audit_logs "databasus-backend/internal/features/audit_logs" workspaces_services "databasus-backend/internal/features/workspaces/services" "databasus-backend/internal/util/encryption" - "databasus-backend/internal/util/logger" ) var ( @@ -34,21 +32,6 @@ func GetStorageController() *StorageController { return storageController } -var ( - setupOnce sync.Once - isSetup atomic.Bool -) - -func SetupDependencies() { - wasAlreadySetup := isSetup.Load() - - setupOnce.Do(func() { - workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(storageService) - - isSetup.Store(true) - }) - - if wasAlreadySetup { - logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") - } -} +var SetupDependencies = sync.OnceFunc(func() { + workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(storageService) +}) diff --git a/backend/internal/features/storages/model_test.go b/backend/internal/features/storages/model_test.go index 3958921..e25cc79 100644 --- a/backend/internal/features/storages/model_test.go +++ b/backend/internal/features/storages/model_test.go @@ -49,7 +49,7 @@ type AzuriteContainer struct { } func Test_Storage_BasicOperations(t *testing.T) { - ctx := context.Background() + ctx := t.Context() validateEnvVariables(t) @@ -227,7 +227,7 @@ acl = private`, s3Container.accessKey, s3Container.secretKey, s3Container.endpoi fileID := uuid.New() err = tc.storage.SaveFile( - context.Background(), + t.Context(), encryptor, logger.GetLogger(), fileID.String(), @@ -250,7 +250,7 @@ acl = private`, s3Container.accessKey, s3Container.secretKey, s3Container.endpoi fileID := uuid.New() err = tc.storage.SaveFile( - context.Background(), + t.Context(), encryptor, logger.GetLogger(), fileID.String(), diff --git a/backend/internal/features/tasks/cancellation/cancel_manager_test.go b/backend/internal/features/tasks/cancellation/cancel_manager_test.go index 50bdb48..e696ac0 100644 --- a/backend/internal/features/tasks/cancellation/cancel_manager_test.go +++ b/backend/internal/features/tasks/cancellation/cancel_manager_test.go @@ -14,7 +14,7 @@ func Test_RegisterTask_TaskRegisteredSuccessfully(t *testing.T) { manager := taskCancelManager taskID := uuid.New() - _, cancel := context.WithCancel(context.Background()) + _, cancel := context.WithCancel(t.Context()) defer cancel() manager.RegisterTask(taskID, cancel) @@ -29,7 +29,7 @@ func Test_UnregisterTask_TaskUnregisteredSuccessfully(t *testing.T) { manager := taskCancelManager taskID := uuid.New() - _, cancel := context.WithCancel(context.Background()) + _, cancel := context.WithCancel(t.Context()) defer cancel() manager.RegisterTask(taskID, cancel) @@ -45,7 +45,7 @@ func Test_CancelTask_OnSameInstance_TaskCancelledViaPubSub(t *testing.T) { manager := taskCancelManager taskID := uuid.New() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) cancelled := false var mu sync.Mutex @@ -79,7 +79,7 @@ func Test_CancelTask_FromDifferentInstance_TaskCancelledOnRunningInstance(t *tes manager2 := taskCancelManager taskID := uuid.New() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) cancelled := false var mu sync.Mutex @@ -131,9 +131,9 @@ func Test_CancelTask_WithMultipleTasks_AllTasksCancelled(t *testing.T) { cancelledFlags := make([]bool, numTasks) var mu sync.Mutex - for i := 0; i < numTasks; i++ { + for i := range numTasks { taskIDs[i] = uuid.New() - contexts[i], cancels[i] = context.WithCancel(context.Background()) + contexts[i], cancels[i] = context.WithCancel(t.Context()) idx := i wrappedCancel := func() { @@ -149,7 +149,7 @@ func Test_CancelTask_WithMultipleTasks_AllTasksCancelled(t *testing.T) { manager.StartSubscription() time.Sleep(100 * time.Millisecond) - for i := 0; i < numTasks; i++ { + for i := range numTasks { err := manager.CancelTask(taskIDs[i]) assert.NoError(t, err, "Cancel should not return error") } @@ -157,7 +157,7 @@ func Test_CancelTask_WithMultipleTasks_AllTasksCancelled(t *testing.T) { time.Sleep(1 * time.Second) mu.Lock() - for i := 0; i < numTasks; i++ { + for i := range numTasks { assert.True(t, cancelledFlags[i], "Task %d should be cancelled", i) assert.Error(t, contexts[i].Err(), "Context %d should be cancelled", i) } @@ -168,7 +168,7 @@ func Test_CancelTask_AfterUnregister_TaskNotCancelled(t *testing.T) { manager := taskCancelManager taskID := uuid.New() - _, cancel := context.WithCancel(context.Background()) + _, cancel := context.WithCancel(t.Context()) defer cancel() cancelled := false diff --git a/backend/internal/features/tasks/cancellation/di.go b/backend/internal/features/tasks/cancellation/di.go index e302653..740b9d6 100644 --- a/backend/internal/features/tasks/cancellation/di.go +++ b/backend/internal/features/tasks/cancellation/di.go @@ -3,7 +3,6 @@ package task_cancellation import ( "context" "sync" - "sync/atomic" "github.com/google/uuid" @@ -22,21 +21,6 @@ func GetTaskCancelManager() *TaskCancelManager { return taskCancelManager } -var ( - setupOnce sync.Once - isSetup atomic.Bool -) - -func SetupDependencies() { - wasAlreadySetup := isSetup.Load() - - setupOnce.Do(func() { - taskCancelManager.StartSubscription() - - isSetup.Store(true) - }) - - if wasAlreadySetup { - logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") - } -} +var SetupDependencies = sync.OnceFunc(func() { + taskCancelManager.StartSubscription() +}) diff --git a/backend/internal/features/tests/mongodb_backup_restore_test.go b/backend/internal/features/tests/mongodb_backup_restore_test.go index ef21647..7074c85 100644 --- a/backend/internal/features/tests/mongodb_backup_restore_test.go +++ b/backend/internal/features/tests/mongodb_backup_restore_test.go @@ -132,7 +132,7 @@ func testMongodbBackupRestoreForVersion( t.Skipf("Skipping MongoDB %s test: %v", mongodbVersion, err) return } - defer container.Client.Disconnect(context.Background()) + defer container.Client.Disconnect(t.Context()) setupMongodbTestData(t, container) @@ -177,7 +177,7 @@ func testMongodbBackupRestoreForVersion( verifyMongodbDataIntegrity(t, container, newDBName) - ctx := context.Background() + ctx := t.Context() _ = container.Client.Database(newDBName).Drop(ctx) err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String())) @@ -206,7 +206,7 @@ func testMongodbBackupRestoreWithEncryptionForVersion( t.Skipf("Skipping MongoDB %s test: %v", mongodbVersion, err) return } - defer container.Client.Disconnect(context.Background()) + defer container.Client.Disconnect(t.Context()) setupMongodbTestData(t, container) @@ -256,7 +256,7 @@ func testMongodbBackupRestoreWithEncryptionForVersion( verifyMongodbDataIntegrity(t, container, newDBName) - ctx := context.Background() + ctx := t.Context() _ = container.Client.Database(newDBName).Drop(ctx) err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String())) @@ -285,7 +285,7 @@ func testMongodbBackupRestoreWithReadOnlyUserForVersion( t.Skipf("Skipping MongoDB %s test: %v", mongodbVersion, err) return } - defer container.Client.Disconnect(context.Background()) + defer container.Client.Disconnect(t.Context()) setupMongodbTestData(t, container) @@ -344,7 +344,7 @@ func testMongodbBackupRestoreWithReadOnlyUserForVersion( verifyMongodbDataIntegrity(t, container, newDBName) - ctx := context.Background() + ctx := t.Context() _ = container.Client.Database(newDBName).Drop(ctx) dropMongodbUserSafe(container.Client, readOnlyUser.Username, container.AuthDatabase) @@ -498,7 +498,7 @@ func waitForMongodbRestoreCompletion( } func verifyMongodbDataIntegrity(t *testing.T, container *MongodbContainer, restoredDBName string) { - ctx := context.Background() + ctx := t.Context() originalCollection := container.Client.Database(container.Database).Collection("test_data") restoredCollection := container.Client.Database(restoredDBName).Collection("test_data") @@ -595,12 +595,12 @@ func connectToMongodbContainer( } func setupMongodbTestData(t *testing.T, container *MongodbContainer) { - ctx := context.Background() + ctx := t.Context() collection := container.Client.Database(container.Database).Collection("test_data") _ = collection.Drop(ctx) - testDocs := []interface{}{ + testDocs := []any{ MongodbTestDataItem{ ID: "1", Name: "test1", diff --git a/backend/internal/features/users/controllers/password_reset_test.go b/backend/internal/features/users/controllers/password_reset_test.go index c56cd2d..b2508cf 100644 --- a/backend/internal/features/users/controllers/password_reset_test.go +++ b/backend/internal/features/users/controllers/password_reset_test.go @@ -539,7 +539,7 @@ func extractCodeFromEmail(emailBody string) string { // Look for pattern:

CODE

// First find

Date: Sat, 28 Mar 2026 22:07:45 +0300 Subject: [PATCH 2/3] FIX (agent): Do not show cancel button for agent backups --- .../features/backups/ui/BackupsComponent.tsx | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/frontend/src/features/backups/ui/BackupsComponent.tsx b/frontend/src/features/backups/ui/BackupsComponent.tsx index a41cc1b..6ccf2fb 100644 --- a/frontend/src/features/backups/ui/BackupsComponent.tsx +++ b/frontend/src/features/backups/ui/BackupsComponent.tsx @@ -302,24 +302,26 @@ export const BackupsComponent = ({ const renderActions = (record: Backup) => { return (
- {record.status === BackupStatus.IN_PROGRESS && isCanManageDBs && ( -
- {cancellingBackupId === record.id ? ( - - ) : ( - - { - if (cancellingBackupId) return; - cancelBackup(record.id); - }} - style={{ color: '#ff0000', opacity: cancellingBackupId ? 0.2 : 1 }} - /> - - )} -
- )} + {record.status === BackupStatus.IN_PROGRESS && + isCanManageDBs && + database.postgresql?.backupType !== PostgresBackupType.WAL_V1 && ( +
+ {cancellingBackupId === record.id ? ( + + ) : ( + + { + if (cancellingBackupId) return; + cancelBackup(record.id); + }} + style={{ color: '#ff0000', opacity: cancellingBackupId ? 0.2 : 1 }} + /> + + )} +
+ )} {record.status === BackupStatus.COMPLETED && (
From c7d091fe512efd918bb58fa9886e9c442fe4fcd3 Mon Sep 17 00:00:00 2001 From: Rostislav Dugin Date: Sat, 28 Mar 2026 22:52:46 +0300 Subject: [PATCH 3/3] FEATURE (agent): Stop WAL and FULL backups on staling within 5 mins --- .../features/api/idle_timeout_reader.go | 60 ++++++++++ .../features/api/idle_timeout_reader_test.go | 112 ++++++++++++++++++ .../internal/features/full_backup/backuper.go | 26 +++- .../features/full_backup/backuper_test.go | 62 ++++++++++ agent/internal/features/wal/streamer.go | 19 ++- agent/internal/features/wal/streamer_test.go | 45 +++++++ 6 files changed, 317 insertions(+), 7 deletions(-) create mode 100644 agent/internal/features/api/idle_timeout_reader.go create mode 100644 agent/internal/features/api/idle_timeout_reader_test.go diff --git a/agent/internal/features/api/idle_timeout_reader.go b/agent/internal/features/api/idle_timeout_reader.go new file mode 100644 index 0000000..6a59ffe --- /dev/null +++ b/agent/internal/features/api/idle_timeout_reader.go @@ -0,0 +1,60 @@ +package api + +import ( + "context" + "fmt" + "io" + "time" +) + +// IdleTimeoutReader wraps an io.Reader and cancels the associated context +// if no bytes are successfully read within the specified timeout duration. +// This detects stalled uploads where the network or source stops transmitting data. +// +// When the idle timeout fires, the reader is also closed (if it implements io.Closer) +// to unblock any goroutine blocked on the underlying Read. +type IdleTimeoutReader struct { + reader io.Reader + timeout time.Duration + cancel context.CancelCauseFunc + timer *time.Timer +} + +// NewIdleTimeoutReader creates a reader that cancels the context via cancel +// if Read does not return any bytes for the given timeout duration. +func NewIdleTimeoutReader(reader io.Reader, timeout time.Duration, cancel context.CancelCauseFunc) *IdleTimeoutReader { + r := &IdleTimeoutReader{ + reader: reader, + timeout: timeout, + cancel: cancel, + } + + r.timer = time.AfterFunc(timeout, func() { + cancel(fmt.Errorf("upload idle timeout: no bytes transmitted for %v", timeout)) + + if closer, ok := reader.(io.Closer); ok { + _ = closer.Close() + } + }) + + return r +} + +func (r *IdleTimeoutReader) Read(p []byte) (int, error) { + n, err := r.reader.Read(p) + + if n > 0 { + r.timer.Reset(r.timeout) + } + + if err != nil { + r.Stop() + } + + return n, err +} + +// Stop cancels the idle timer. Must be called when the reader is no longer needed. +func (r *IdleTimeoutReader) Stop() { + r.timer.Stop() +} diff --git a/agent/internal/features/api/idle_timeout_reader_test.go b/agent/internal/features/api/idle_timeout_reader_test.go new file mode 100644 index 0000000..17d52ff --- /dev/null +++ b/agent/internal/features/api/idle_timeout_reader_test.go @@ -0,0 +1,112 @@ +package api + +import ( + "context" + "fmt" + "io" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ReadThroughIdleTimeoutReader_WhenBytesFlowContinuously_DoesNotCancelContext(t *testing.T) { + ctx, cancel := context.WithCancelCause(t.Context()) + defer cancel(nil) + + pr, pw := io.Pipe() + + idleReader := NewIdleTimeoutReader(pr, 200*time.Millisecond, cancel) + defer idleReader.Stop() + + go func() { + for range 5 { + _, _ = pw.Write([]byte("data")) + time.Sleep(50 * time.Millisecond) + } + + _ = pw.Close() + }() + + data, err := io.ReadAll(idleReader) + + require.NoError(t, err) + assert.Equal(t, "datadatadatadatadata", string(data)) + assert.NoError(t, ctx.Err(), "context should not be cancelled when bytes flow continuously") +} + +func Test_ReadThroughIdleTimeoutReader_WhenNoBytesTransmitted_CancelsContext(t *testing.T) { + ctx, cancel := context.WithCancelCause(t.Context()) + defer cancel(nil) + + pr, _ := io.Pipe() + + idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel) + defer idleReader.Stop() + + time.Sleep(200 * time.Millisecond) + + assert.Error(t, ctx.Err(), "context should be cancelled when no bytes are transmitted") + assert.Contains(t, context.Cause(ctx).Error(), "upload idle timeout") +} + +func Test_ReadThroughIdleTimeoutReader_WhenBytesStopMidStream_CancelsContext(t *testing.T) { + ctx, cancel := context.WithCancelCause(t.Context()) + defer cancel(nil) + + pr, pw := io.Pipe() + + idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel) + defer idleReader.Stop() + + go func() { + _, _ = pw.Write([]byte("initial")) + // Stop writing — simulate stalled source + }() + + buf := make([]byte, 1024) + n, _ := idleReader.Read(buf) + assert.Equal(t, "initial", string(buf[:n])) + + time.Sleep(200 * time.Millisecond) + + assert.Error(t, ctx.Err(), "context should be cancelled when bytes stop mid-stream") + assert.Contains(t, context.Cause(ctx).Error(), "upload idle timeout") +} + +func Test_StopIdleTimeoutReader_WhenCalledBeforeTimeout_DoesNotCancelContext(t *testing.T) { + ctx, cancel := context.WithCancelCause(t.Context()) + defer cancel(nil) + + pr, _ := io.Pipe() + + idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel) + idleReader.Stop() + + time.Sleep(200 * time.Millisecond) + + assert.NoError(t, ctx.Err(), "context should not be cancelled when reader is stopped before timeout") +} + +func Test_ReadThroughIdleTimeoutReader_WhenReaderReturnsError_PropagatesError(t *testing.T) { + ctx, cancel := context.WithCancelCause(t.Context()) + defer cancel(nil) + + pr, pw := io.Pipe() + + idleReader := NewIdleTimeoutReader(pr, 5*time.Second, cancel) + defer idleReader.Stop() + + expectedErr := fmt.Errorf("test read error") + _ = pw.CloseWithError(expectedErr) + + buf := make([]byte, 1024) + _, err := idleReader.Read(buf) + + assert.ErrorIs(t, err, expectedErr) + + // Timer should be stopped after error — context should not be cancelled + time.Sleep(100 * time.Millisecond) + assert.NoError(t, ctx.Err(), "context should not be cancelled after reader error stops the timer") +} diff --git a/agent/internal/features/full_backup/backuper.go b/agent/internal/features/full_backup/backuper.go index e336311..1e12221 100644 --- a/agent/internal/features/full_backup/backuper.go +++ b/agent/internal/features/full_backup/backuper.go @@ -21,9 +21,11 @@ import ( const ( checkInterval = 30 * time.Second retryDelay = 1 * time.Minute - uploadTimeout = 30 * time.Minute + uploadTimeout = 23 * time.Hour ) +var uploadIdleTimeout = 5 * time.Minute + var retryDelayOverride *time.Duration type CmdBuilder func(ctx context.Context) *exec.Cmd @@ -176,16 +178,32 @@ func (backuper *FullBackuper) executeAndUploadBasebackup(ctx context.Context) er // Phase 1: Stream compressed data via io.Pipe directly to the API. pipeReader, pipeWriter := io.Pipe() + defer func() { _ = pipeReader.Close() }() + go backuper.compressAndStream(pipeWriter, stdoutPipe) - uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout) - defer cancel() + uploadCtx, timeoutCancel := context.WithTimeout(ctx, uploadTimeout) + defer timeoutCancel() - uploadResp, uploadErr := backuper.apiClient.UploadBasebackup(uploadCtx, pipeReader) + idleCtx, idleCancel := context.WithCancelCause(uploadCtx) + defer idleCancel(nil) + + idleReader := api.NewIdleTimeoutReader(pipeReader, uploadIdleTimeout, idleCancel) + defer idleReader.Stop() + + uploadResp, uploadErr := backuper.apiClient.UploadBasebackup(idleCtx, idleReader) + + if uploadErr != nil && cmd.Process != nil { + _ = cmd.Process.Kill() + } cmdErr := cmd.Wait() if uploadErr != nil { + if cause := context.Cause(idleCtx); cause != nil { + uploadErr = cause + } + stderrStr := stderrBuf.String() if stderrStr != "" { return fmt.Errorf("upload basebackup: %w (pg_basebackup stderr: %s)", uploadErr, stderrStr) diff --git a/agent/internal/features/full_backup/backuper_test.go b/agent/internal/features/full_backup/backuper_test.go index 7d423a4..610b3fb 100644 --- a/agent/internal/features/full_backup/backuper_test.go +++ b/agent/internal/features/full_backup/backuper_test.go @@ -562,6 +562,68 @@ func Test_RunFullBackup_WhenUploadSucceeds_BodyIsZstdCompressed(t *testing.T) { assert.Equal(t, originalContent, string(decompressed)) } +func Test_RunFullBackup_WhenUploadStalls_FailsWithIdleTimeout(t *testing.T) { + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testFullStartPath: + // Server reads body normally — it will block until connection is closed + _, _ = io.ReadAll(r.Body) + writeJSON(w, map[string]string{"backupId": testBackupID}) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = stallingCmdBuilder(t) + + origIdleTimeout := uploadIdleTimeout + uploadIdleTimeout = 200 * time.Millisecond + defer func() { uploadIdleTimeout = origIdleTimeout }() + + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() + + err := fb.executeAndUploadBasebackup(ctx) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "idle timeout", "error should mention idle timeout") +} + +func stallingCmdBuilder(t *testing.T) CmdBuilder { + t.Helper() + + return func(ctx context.Context) *exec.Cmd { + cmd := exec.CommandContext(ctx, os.Args[0], + "-test.run=TestHelperProcessStalling", + "--", + ) + + cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS_STALLING=1") + + return cmd + } +} + +func TestHelperProcessStalling(t *testing.T) { + if os.Getenv("GO_TEST_HELPER_PROCESS_STALLING") != "1" { + return + } + + // Write enough data to flush through the zstd encoder's internal buffer (~128KB blocks). + // Without enough data, zstd buffers everything and the pipe never receives bytes. + data := make([]byte, 256*1024) + for i := range data { + data[i] = byte(i) + } + _, _ = os.Stdout.Write(data) + + // Stall with stdout open — the compress goroutine blocks on its next read. + // The parent process will kill us when the context is cancelled. + time.Sleep(time.Hour) + os.Exit(0) +} + func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server { t.Helper() diff --git a/agent/internal/features/wal/streamer.go b/agent/internal/features/wal/streamer.go index 7ae5337..d5b5702 100644 --- a/agent/internal/features/wal/streamer.go +++ b/agent/internal/features/wal/streamer.go @@ -18,6 +18,8 @@ import ( "databasus-agent/internal/features/api" ) +var uploadIdleTimeout = 5 * time.Minute + const ( pollInterval = 10 * time.Second uploadTimeout = 5 * time.Minute @@ -122,16 +124,27 @@ func (s *Streamer) uploadSegment(ctx context.Context, segmentName string) error filePath := filepath.Join(s.cfg.PgWalDir, segmentName) pr, pw := io.Pipe() + defer func() { _ = pr.Close() }() go s.compressAndStream(pw, filePath) - uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout) - defer cancel() + uploadCtx, timeoutCancel := context.WithTimeout(ctx, uploadTimeout) + defer timeoutCancel() + + idleCtx, idleCancel := context.WithCancelCause(uploadCtx) + defer idleCancel(nil) + + idleReader := api.NewIdleTimeoutReader(pr, uploadIdleTimeout, idleCancel) + defer idleReader.Stop() s.log.Info("Uploading WAL segment", "segment", segmentName) - result, err := s.apiClient.UploadWalSegment(uploadCtx, segmentName, pr) + result, err := s.apiClient.UploadWalSegment(idleCtx, segmentName, idleReader) if err != nil { + if cause := context.Cause(idleCtx); cause != nil { + return fmt.Errorf("upload WAL segment: %w", cause) + } + return err } diff --git a/agent/internal/features/wal/streamer_test.go b/agent/internal/features/wal/streamer_test.go index fc161a8..5647cbf 100644 --- a/agent/internal/features/wal/streamer_test.go +++ b/agent/internal/features/wal/streamer_test.go @@ -2,6 +2,7 @@ package wal import ( "context" + "crypto/rand" "encoding/json" "io" "net/http" @@ -9,6 +10,7 @@ import ( "os" "path/filepath" "sync" + "sync/atomic" "testing" "time" @@ -287,6 +289,49 @@ func Test_UploadSegment_ServerReturns409_FileNotDeleted(t *testing.T) { assert.NoError(t, err, "segment file should not be deleted on gap detection") } +func Test_UploadSegment_WhenUploadStalls_FailsWithIdleTimeout(t *testing.T) { + walDir := createTestWalDir(t) + + // Use incompressible random data to ensure TCP buffers fill up + segmentContent := make([]byte, 1024*1024) + _, err := rand.Read(segmentContent) + require.NoError(t, err) + + writeTestSegment(t, walDir, "000000010000000100000001", segmentContent) + + var requestReceived atomic.Bool + handlerDone := make(chan struct{}) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestReceived.Store(true) + + // Read one byte then stall — simulates a network stall + buf := make([]byte, 1) + _, _ = r.Body.Read(buf) + <-handlerDone + })) + defer server.Close() + defer close(handlerDone) + + origIdleTimeout := uploadIdleTimeout + uploadIdleTimeout = 200 * time.Millisecond + defer func() { uploadIdleTimeout = origIdleTimeout }() + + streamer := newTestStreamer(walDir, server.URL) + + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() + + uploadErr := streamer.uploadSegment(ctx, "000000010000000100000001") + + assert.Error(t, uploadErr, "upload should fail when stalled") + assert.True(t, requestReceived.Load(), "server should have received the request") + assert.Contains(t, uploadErr.Error(), "idle timeout", "error should mention idle timeout") + + _, statErr := os.Stat(filepath.Join(walDir, "000000010000000100000001")) + assert.NoError(t, statErr, "segment file should remain in queue after idle timeout") +} + func newTestStreamer(walDir, serverURL string) *Streamer { cfg := createTestConfig(walDir, serverURL) apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger())