Compare commits

...

46 Commits

Author SHA1 Message Date
Rostislav Dugin
7123de9fa3 Merge pull request #486 from databasus/develop
FIX (docker): Chown /var/run/postgresql after UID/GID adjustment to f…
2026-03-31 14:25:15 +03:00
Rostislav Dugin
d1c41ed53a FIX (docker): Chown /var/run/postgresql after UID/GID adjustment to fix PostgreSQL lock file permission denied on startup 2026-03-31 14:24:43 +03:00
Rostislav Dugin
c2ddbfc86f Merge pull request #484 from databasus/develop
FEATURE (docker): Add PUID/PGID environment variables to control post…
2026-03-31 11:52:23 +03:00
Rostislav Dugin
f287967b5d FEATURE (docker): Add PUID/PGID environment variables to control postgres user UID/GID for host-level backup compatibility 2026-03-31 11:51:57 +03:00
Rostislav Dugin
ef879df08f Merge pull request #483 from databasus/develop
FIX (backups): Use system's temp directory instead of mounter directo…
2026-03-31 11:41:26 +03:00
Rostislav Dugin
44ddcb836e FIX (backups): Use system's temp directory instead of mounter directory to fix access permissions on TrueNAS 2026-03-31 11:40:11 +03:00
Rostislav Dugin
b5178f5752 Merge pull request #482 from databasus/develop
FEATURE (clipboard): Add parsing from clipboard via dialog in HTTP\no…
2026-03-31 11:21:25 +03:00
Rostislav Dugin
7913c1b474 FEATURE (clipboard): Add parsing from clipboard via dialog in HTTP\no navigator mode 2026-03-31 11:20:13 +03:00
Rostislav Dugin
2815cc3752 Merge pull request #481 from databasus/develop
FIX (storages): Validat only single rclone storage is passed
2026-03-31 10:37:54 +03:00
Rostislav Dugin
189573fa1b FIX (storages): Validat only single rclone storage is passed 2026-03-31 10:37:13 +03:00
Rostislav Dugin
81f77760c9 Merge pull request #479 from databasus/develop
FEATURE (navbar): Update navbar link color
2026-03-30 13:16:07 +03:00
Rostislav Dugin
63e23b2489 FEATURE (navbar): Update navbar link color 2026-03-30 13:15:03 +03:00
Rostislav Dugin
8c1b8ac00f Merge pull request #477 from databasus/develop
Develop
2026-03-29 15:45:15 +03:00
Rostislav Dugin
1926096377 FEATURE (backups): Add filters to backups panel 2026-03-29 15:33:01 +03:00
Rostislav Dugin
0a131511a8 FIX (agent): Fix uploading WAL to storages 2026-03-29 14:35:46 +03:00
Rostislav Dugin
aa01ce0b76 FEATURE (agent): Make installation guide more structured 2026-03-29 14:34:33 +03:00
Rostislav Dugin
496fc05993 Merge pull request #474 from databasus/develop
FIX (wal): Fix timeout upload test
2026-03-29 11:30:01 +03:00
Rostislav Dugin
1ac0eb4d5b FIX (wal): Fix timeout upload test 2026-03-29 11:29:44 +03:00
Rostislav Dugin
6b052902f7 Merge pull request #473 from databasus/develop
Develop
2026-03-28 22:53:13 +03:00
Rostislav Dugin
c7d091fe51 FEATURE (agent): Stop WAL and FULL backups on staling within 5 mins 2026-03-28 22:52:46 +03:00
Rostislav Dugin
b1dfd1c425 FIX (agent): Do not show cancel button for agent backups 2026-03-28 22:07:45 +03:00
Rostislav Dugin
4bee78646a REFACTOR (go): Refactor go to follow modern syntax guidelines 2026-03-28 22:02:22 +03:00
Rostislav Dugin
927eeabc0f Merge pull request #470 from databasus/develop
Develop
2026-03-27 23:58:39 +03:00
Rostislav Dugin
3a5a53c92d FIX (backups): Fix light theme for banner 2026-03-27 23:44:00 +03:00
Rostislav Dugin
f0ab470a84 FEATURE (readme): Update readme 2026-03-27 23:00:25 +03:00
Rostislav Dugin
f728fda759 FIX (backups): Hide hourly GFS when daily and more period selected 2026-03-27 22:51:40 +03:00
Rostislav Dugin
80b5df6283 FIX (billing): Fix UI units 2026-03-27 22:50:08 +03:00
Rostislav Dugin
67556a0db1 FIX (dockerfile): Fix index.html 2026-03-27 22:33:20 +03:00
Rostislav Dugin
c4cf7f8446 Merge branch 'develop' of https://github.com/databasus/databasus into develop 2026-03-27 22:02:40 +03:00
Rostislav Dugin
61a0bcabb1 FEATURE (cloud): Add cloud 2026-03-27 22:02:25 +03:00
Rostislav Dugin
f7f70a13eb Merge pull request #465 from databasus/develop
Develop
2026-03-24 18:01:35 +03:00
Rostislav Dugin
f1e289c421 Merge pull request #464 from kvendingoldo/main
feat: add support of imagePullSecrets
2026-03-24 18:00:15 +03:00
Alexander Sharov
c0952e057f feat: add support of imagePullSecrets
Signed-off-by: Alexander Sharov <kvendingoldo@yandex.ru>
2026-03-24 18:53:10 +04:00
Rostislav Dugin
b4d4e0a1d7 Merge pull request #459 from databasus/develop
FIX (readme): Fix typo in readme
2026-03-22 13:42:39 +03:00
Rostislav Dugin
c648e9c29f FIX (readme): Fix typo in readme 2026-03-22 13:41:56 +03:00
Rostislav Dugin
3fce6d2a99 Merge pull request #458 from databasus/develop
FIX (playground): Move index.html for playground into <noscript>
2026-03-22 09:58:48 +03:00
Rostislav Dugin
198b94ba9d FIX (playground): Move index.html for playground into <noscript> 2026-03-22 09:57:24 +03:00
Rostislav Dugin
80cd0bf5d3 Merge pull request #457 from databasus/develop
FIX (playground): Make turnstile mandatory in sign in and sign up
2026-03-21 22:32:26 +03:00
Rostislav Dugin
231e3cc709 FIX (playground): Make turnstile mandatory in sign in and sign up 2026-03-21 22:30:16 +03:00
Rostislav Dugin
8cf0fdacb1 Merge pull request #456 from databasus/develop
Develop
2026-03-21 14:15:21 +03:00
Rostislav Dugin
2d28af19dc FEATURE (playground): Remove playground warning 2026-03-21 14:14:22 +03:00
Rostislav Dugin
67dc257fda FIX (mariadb\mysql): Skip SSL if https mode is set to false 2026-03-21 14:12:53 +03:00
Rostislav Dugin
881167f812 FEATURE (index.html): Adjust policices for playgronund index 2026-03-21 14:08:09 +03:00
Rostislav Dugin
cf807cfc54 FIX (mariadb\mysql): Skip tables locking over restores 2026-03-21 12:55:34 +03:00
Rostislav Dugin
df91651709 Merge pull request #455 from databasus/develop
FIX (readme): Fix FAQ link
2026-03-21 12:44:22 +03:00
Rostislav Dugin
b0592dae9e FIX (readme): Fix FAQ link 2026-03-21 12:43:55 +03:00
197 changed files with 11630 additions and 3546 deletions

832
AGENTS.md

File diff suppressed because it is too large Load Diff

View File

@@ -239,7 +239,8 @@ RUN apt-get update && \
fi
# Create postgres user and set up directories
RUN useradd -m -s /bin/bash postgres || true && \
RUN groupadd -g 999 postgres || true && \
useradd -m -s /bin/bash -u 999 -g 999 postgres || true && \
mkdir -p /databasus-data/pgdata && \
chown -R postgres:postgres /databasus-data/pgdata
@@ -257,6 +258,9 @@ COPY backend/migrations ./migrations
# Copy UI files
COPY --from=backend-build /app/ui/build ./ui/build
# Copy cloud static HTML template (injected into index.html at startup when IS_CLOUD=true)
COPY frontend/cloud-root-content.html /app/cloud-root-content.html
# Copy agent binaries (both architectures) — served by the backend
# at GET /api/v1/system/agent?arch=amd64|arm64
COPY --from=agent-build /agent-binaries ./agent-binaries
@@ -291,6 +295,25 @@ if [ -d "/postgresus-data" ] && [ "\$(ls -A /postgresus-data 2>/dev/null)" ]; th
exit 1
fi
# ========= Adjust postgres user UID/GID =========
PUID=\${PUID:-999}
PGID=\${PGID:-999}
CURRENT_UID=\$(id -u postgres)
CURRENT_GID=\$(id -g postgres)
if [ "\$CURRENT_GID" != "\$PGID" ]; then
echo "Adjusting postgres group GID from \$CURRENT_GID to \$PGID..."
groupmod -o -g "\$PGID" postgres
fi
if [ "\$CURRENT_UID" != "\$PUID" ]; then
echo "Adjusting postgres user UID from \$CURRENT_UID to \$PUID..."
usermod -o -u "\$PUID" postgres
fi
chown -R postgres:postgres /var/run/postgresql
# PostgreSQL 17 binary paths
PG_BIN="/usr/lib/postgresql/17/bin"
@@ -313,7 +336,9 @@ window.__RUNTIME_CONFIG__ = {
GOOGLE_CLIENT_ID: '\${GOOGLE_CLIENT_ID:-}',
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED',
CLOUDFLARE_TURNSTILE_SITE_KEY: '\${CLOUDFLARE_TURNSTILE_SITE_KEY:-}',
CONTAINER_ARCH: '\${CONTAINER_ARCH:-unknown}'
CONTAINER_ARCH: '\${CONTAINER_ARCH:-unknown}',
CLOUD_PRICE_PER_GB: '\${CLOUD_PRICE_PER_GB:-}',
CLOUD_PADDLE_CLIENT_TOKEN: '\${CLOUD_PADDLE_CLIENT_TOKEN:-}'
};
JSEOF
@@ -326,6 +351,32 @@ if [ -n "\${ANALYTICS_SCRIPT:-}" ]; then
fi
fi
# Inject Paddle script if client token is provided (only if not already injected)
if [ -n "\${CLOUD_PADDLE_CLIENT_TOKEN:-}" ]; then
if ! grep -q "cdn.paddle.com" /app/ui/build/index.html 2>/dev/null; then
echo "Injecting Paddle script..."
sed -i "s#</head># <script src=\"https://cdn.paddle.com/paddle/v2/paddle.js\"></script>\\
</head>#" /app/ui/build/index.html
fi
fi
# Inject static HTML into root div for cloud mode (payment system requires visible legal links)
if [ "\${IS_CLOUD:-false}" = "true" ]; then
if ! grep -q "cloud-static-content" /app/ui/build/index.html 2>/dev/null; then
echo "Injecting cloud static HTML content..."
perl -i -pe '
BEGIN {
open my \$fh, "<", "/app/cloud-root-content.html" or die;
local \$/;
\$c = <\$fh>;
close \$fh;
\$c =~ s/\\n/ /g;
}
s/<div id="root"><\\/div>/<div id="root"><!-- cloud-static-content --><noscript>\$c<\\/noscript><\\/div>/
' /app/ui/build/index.html
fi
fi
# Ensure proper ownership of data directory
echo "Setting up data directory permissions..."
mkdir -p /databasus-data/pgdata
@@ -439,7 +490,7 @@ fi
echo "Setting up database and user..."
gosu postgres \$PG_BIN/psql -p 5437 -h localhost -d postgres << 'SQL'
# We use stub password, because internal DB is not exposed outside container
-- We use stub password, because internal DB is not exposed outside container
ALTER USER postgres WITH PASSWORD 'Q1234567';
SELECT 'CREATE DATABASE databasus OWNER postgres'
WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = 'databasus')

22
NOTICE.md Normal file
View File

@@ -0,0 +1,22 @@
Copyright © 2025–2026 Rostislav Dugin and contributors.
“Databasus” is a trademark of Rostislav Dugin.
The source code in this repository is licensed under the Apache License, Version 2.0.
That license applies to the code only and does not grant any right to use the
Databasus name, logo, or branding, except for reasonable and customary referential
use in describing the origin of the software and reproducing the content of this NOTICE.
Permitted referential use includes truthful use of the name “Databasus” to identify
the original Databasus project in software catalogs, deployment templates, hosting
panels, package indexes, compatibility pages, integrations, tutorials, reviews, and
similar informational materials, including phrases such as “Databasus”,
“Deploy Databasus”, “Databasus on Coolify”, and “Compatible with Databasus”.
You may not use “Databasus” as the name or primary branding of a competing product,
service, fork, distribution, or hosted offering, or in any manner likely to cause
confusion as to source, affiliation, sponsorship, or endorsement.
Nothing in this repository transfers, waives, limits, or estops any rights in the
Databasus mark. All trademark rights are reserved except for the limited referential
use stated above.

View File

@@ -1,8 +1,8 @@
<div align="center">
<img src="assets/logo.svg" alt="Databasus Logo" width="250"/>
<h3>Backup tool for PostgreSQL, MySQL and MongoDB</h3>
<p>Databasus is a free, open source and self-hosted tool to backup databases (with focus on PostgreSQL). Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.)</p>
<h3>PostgreSQL backup tool (with MySQL\MariaDB and MongoDB support)</h3>
<p>Databasus is a free, open source and self-hosted tool to backup databases (with primary focus on PostgreSQL). Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.)</p>
<!-- Badges -->
[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-336791?logo=postgresql&logoColor=white)](https://www.postgresql.org/)
@@ -99,8 +99,8 @@ It is also important for Databasus that you are able to decrypt and restore back
### 📦 **Backup types**
- **Logical** — Native dump of the database in its engine-specific binary format. Compressed and streamed directly to storage with no intermediate files
- **Physical** — File-level copy of the entire database cluster. Faster backup and restore for large datasets compared to logical dumps (requires agent)
- **Incremental** — Physical base backup combined with continuous WAL segment archiving. Enables Point-in-Time Recovery (PITR) — restore to any second between backups. Designed for disaster recovery and near-zero data loss requirements (requires agent)
- **Physical** — File-level copy of the entire database cluster. Faster backup and restore for large datasets compared to logical dumps
- **Incremental** — Physical base backup combined with continuous WAL segment archiving. Enables Point-in-time recovery (PITR) — restore to any second between backups. Designed for disaster recovery and near-zero data loss requirements
### 🐳 **Self-hosted & secure**
@@ -245,7 +245,7 @@ Replace `admin` with the actual email address of the user whose password you wan
### 💾 Backuping Databasus itself
After installation, it is also recommended to <a href="https://databasus.com/faq/#backup-databasus">backup your Databasus itself</a> or, at least, to copy secret key used for encryption (30 seconds is needed). So you are able to restore from your encrypted backups if you lose access to the server with Databasus or it is corrupted.
After installation, it is also recommended to <a href="https://databasus.com/faq#backup-databasus">backup your Databasus itself</a> or, at least, to copy secret key used for encryption (30 seconds is needed). So you are able to restore from your encrypted backups if you lose access to the server with Databasus or it is corrupted.
---
@@ -259,7 +259,9 @@ Contributions are welcome! Read the <a href="https://databasus.com/contribute">c
Also you can join our large community of developers, DBAs and DevOps engineers on Telegram [@databasus_community](https://t.me/databasus_community).
## AI disclaimer
## FAQ
### AI disclaimer
There have been questions about AI usage in project development in issues and discussions. As the project focuses on security, reliability and production usage, it's important to explain how AI is used in the development process.
@@ -295,3 +297,13 @@ Moreover, it's important to note that we do not differentiate between bad human
Even if code is written manually by a human, it's not guaranteed to be merged. Vibe code is not allowed at all and all such PRs are rejected by default (see [contributing guide](https://databasus.com/contribute)).
We also draw attention to fast issue resolution and security [vulnerability reporting](https://github.com/databasus/databasus?tab=security-ov-file#readme).
### You have a cloud version — are you truly open source?
Yes. Every feature available in Databasus Cloud is equally available in the self-hosted version with no restrictions, no feature gates and no usage limits. The entire codebase is Apache 2.0 licensed and always will be.
Databasus is not "open core." We do not withhold features behind a paid tier and then call the limited remainder "open source," as projects like GitLab or Sentry do. We believe open source means the complete product is open, not just a marketing label on a stripped-down edition.
Databasus Cloud runs the exact same code as the self-hosted version. The only difference is that we take care of infrastructure, availability, backups, reservations, monitoring and updates for you — so you don't have to. If you are using cloud, you can always move your databases from cloud to self-hosted if you wish.
Revenue from Cloud funds full-time development of the project. Most large open-source projects rely on corporate backing or sponsorship to survive. We chose a different path: Databasus sustains itself so it can grow and improve independently, without being tied to any enterprise or sponsor.

View File

@@ -110,8 +110,7 @@ func (c *Config) applyDefaults() {
}
if c.IsDeleteWalAfterUpload == nil {
v := true
c.IsDeleteWalAfterUpload = &v
c.IsDeleteWalAfterUpload = new(true)
}
}

View File

@@ -0,0 +1,60 @@
package api
import (
"context"
"fmt"
"io"
"time"
)
// IdleTimeoutReader wraps an io.Reader and cancels the associated context
// if no bytes are successfully read within the specified timeout duration.
// This detects stalled uploads where the network or source stops transmitting data.
//
// When the idle timeout fires, the reader is also closed (if it implements io.Closer)
// to unblock any goroutine blocked on the underlying Read.
type IdleTimeoutReader struct {
reader io.Reader
timeout time.Duration
cancel context.CancelCauseFunc
timer *time.Timer
}
// NewIdleTimeoutReader creates a reader that cancels the context via cancel
// if Read does not return any bytes for the given timeout duration.
func NewIdleTimeoutReader(reader io.Reader, timeout time.Duration, cancel context.CancelCauseFunc) *IdleTimeoutReader {
r := &IdleTimeoutReader{
reader: reader,
timeout: timeout,
cancel: cancel,
}
r.timer = time.AfterFunc(timeout, func() {
cancel(fmt.Errorf("upload idle timeout: no bytes transmitted for %v", timeout))
if closer, ok := reader.(io.Closer); ok {
_ = closer.Close()
}
})
return r
}
func (r *IdleTimeoutReader) Read(p []byte) (int, error) {
n, err := r.reader.Read(p)
if n > 0 {
r.timer.Reset(r.timeout)
}
if err != nil && err != io.EOF {
r.Stop()
}
return n, err
}
// Stop cancels the idle timer. Must be called when the reader is no longer needed.
func (r *IdleTimeoutReader) Stop() {
r.timer.Stop()
}

View File

@@ -0,0 +1,112 @@
package api
import (
"context"
"fmt"
"io"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Happy path: as long as bytes keep arriving faster than the idle timeout,
// the context must stay alive and all data must be delivered unmodified.
func Test_ReadThroughIdleTimeoutReader_WhenBytesFlowContinuously_DoesNotCancelContext(t *testing.T) {
	ctx, cancel := context.WithCancelCause(t.Context())
	defer cancel(nil)

	pr, pw := io.Pipe()
	// 50ms between writes is well under the 200ms idle timeout, so the
	// timer is reset before it can ever fire.
	idleReader := NewIdleTimeoutReader(pr, 200*time.Millisecond, cancel)
	defer idleReader.Stop()

	go func() {
		for range 5 {
			_, _ = pw.Write([]byte("data"))
			time.Sleep(50 * time.Millisecond)
		}
		_ = pw.Close()
	}()

	data, err := io.ReadAll(idleReader)
	require.NoError(t, err)
	assert.Equal(t, "datadatadatadatadata", string(data))
	assert.NoError(t, ctx.Err(), "context should not be cancelled when bytes flow continuously")
}
// A reader that never produces any bytes must trip the idle timeout and
// cancel the context with a descriptive cause.
func Test_ReadThroughIdleTimeoutReader_WhenNoBytesTransmitted_CancelsContext(t *testing.T) {
	ctx, cancel := context.WithCancelCause(t.Context())
	defer cancel(nil)

	// Nothing is ever written to the pipe, so the timer fires without being reset.
	pr, _ := io.Pipe()
	idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel)
	defer idleReader.Stop()

	// Wait past the timeout so the AfterFunc callback has a chance to run.
	time.Sleep(200 * time.Millisecond)

	assert.Error(t, ctx.Err(), "context should be cancelled when no bytes are transmitted")
	assert.Contains(t, context.Cause(ctx).Error(), "upload idle timeout")
}
// Stalled-upload scenario: data flows initially, then stops. The idle timer
// must fire relative to the LAST successful read, cancelling the context.
func Test_ReadThroughIdleTimeoutReader_WhenBytesStopMidStream_CancelsContext(t *testing.T) {
	ctx, cancel := context.WithCancelCause(t.Context())
	defer cancel(nil)

	pr, pw := io.Pipe()
	idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel)
	defer idleReader.Stop()

	go func() {
		_, _ = pw.Write([]byte("initial"))
		// Stop writing — simulate stalled source
	}()

	buf := make([]byte, 1024)
	n, _ := idleReader.Read(buf)
	assert.Equal(t, "initial", string(buf[:n]))

	// No further bytes arrive; the timeout elapses after the first read.
	time.Sleep(200 * time.Millisecond)

	assert.Error(t, ctx.Err(), "context should be cancelled when bytes stop mid-stream")
	assert.Contains(t, context.Cause(ctx).Error(), "upload idle timeout")
}
// Stopping the reader before the timer fires must disarm the idle timeout
// entirely: even after waiting past the deadline, the context stays alive.
func Test_StopIdleTimeoutReader_WhenCalledBeforeTimeout_DoesNotCancelContext(t *testing.T) {
	ctx, cancel := context.WithCancelCause(t.Context())
	defer cancel(nil)

	source, _ := io.Pipe()
	reader := NewIdleTimeoutReader(source, 100*time.Millisecond, cancel)

	// Disarm immediately, then wait out the original deadline.
	reader.Stop()
	time.Sleep(200 * time.Millisecond)

	assert.NoError(t, ctx.Err(), "context should not be cancelled when reader is stopped before timeout")
}
// Errors from the underlying reader must be surfaced unchanged to the caller,
// and such an error must stop the idle timer so the context is not cancelled
// afterwards.
func Test_ReadThroughIdleTimeoutReader_WhenReaderReturnsError_PropagatesError(t *testing.T) {
	ctx, cancel := context.WithCancelCause(t.Context())
	defer cancel(nil)

	pr, pw := io.Pipe()
	idleReader := NewIdleTimeoutReader(pr, 5*time.Second, cancel)
	defer idleReader.Stop()

	// Closing the write side with an error makes the next Read return it.
	expectedErr := fmt.Errorf("test read error")
	_ = pw.CloseWithError(expectedErr)

	buf := make([]byte, 1024)
	_, err := idleReader.Read(buf)
	assert.ErrorIs(t, err, expectedErr)

	// Timer should be stopped after error — context should not be cancelled
	time.Sleep(100 * time.Millisecond)
	assert.NoError(t, ctx.Err(), "context should not be cancelled after reader error stops the timer")
}

View File

@@ -21,9 +21,11 @@ import (
const (
checkInterval = 30 * time.Second
retryDelay = 1 * time.Minute
uploadTimeout = 30 * time.Minute
uploadTimeout = 23 * time.Hour
)
var uploadIdleTimeout = 5 * time.Minute
var retryDelayOverride *time.Duration
type CmdBuilder func(ctx context.Context) *exec.Cmd
@@ -176,16 +178,32 @@ func (backuper *FullBackuper) executeAndUploadBasebackup(ctx context.Context) er
// Phase 1: Stream compressed data via io.Pipe directly to the API.
pipeReader, pipeWriter := io.Pipe()
defer func() { _ = pipeReader.Close() }()
go backuper.compressAndStream(pipeWriter, stdoutPipe)
uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout)
defer cancel()
uploadCtx, timeoutCancel := context.WithTimeout(ctx, uploadTimeout)
defer timeoutCancel()
uploadResp, uploadErr := backuper.apiClient.UploadBasebackup(uploadCtx, pipeReader)
idleCtx, idleCancel := context.WithCancelCause(uploadCtx)
defer idleCancel(nil)
idleReader := api.NewIdleTimeoutReader(pipeReader, uploadIdleTimeout, idleCancel)
defer idleReader.Stop()
uploadResp, uploadErr := backuper.apiClient.UploadBasebackup(idleCtx, idleReader)
if uploadErr != nil && cmd.Process != nil {
_ = cmd.Process.Kill()
}
cmdErr := cmd.Wait()
if uploadErr != nil {
if cause := context.Cause(idleCtx); cause != nil {
uploadErr = cause
}
stderrStr := stderrBuf.String()
if stderrStr != "" {
return fmt.Errorf("upload basebackup: %w (pg_basebackup stderr: %s)", uploadErr, stderrStr)

View File

@@ -71,7 +71,7 @@ func Test_RunFullBackup_WhenChainBroken_BasebackupTriggered(t *testing.T) {
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "test-backup-data", validStderr())
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
defer cancel()
go fb.Run(ctx)
@@ -124,7 +124,7 @@ func Test_RunFullBackup_WhenScheduledBackupDue_BasebackupTriggered(t *testing.T)
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "scheduled-backup-data", validStderr())
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
defer cancel()
go fb.Run(ctx)
@@ -169,7 +169,7 @@ func Test_RunFullBackup_WhenNoFullBackupExists_ImmediateBasebackupTriggered(t *t
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "first-backup-data", validStderr())
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
defer cancel()
go fb.Run(ctx)
@@ -233,7 +233,7 @@ func Test_RunFullBackup_WhenUploadFails_RetriesAfterDelay(t *testing.T) {
setRetryDelay(100 * time.Millisecond)
defer setRetryDelay(origRetryDelay)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
defer cancel()
go fb.Run(ctx)
@@ -282,7 +282,7 @@ func Test_RunFullBackup_WhenAlreadyRunning_SkipsExecution(t *testing.T) {
fb.isRunning.Store(true)
fb.checkAndRunIfNeeded(context.Background())
fb.checkAndRunIfNeeded(t.Context())
mu.Lock()
count := uploadCount
@@ -318,7 +318,7 @@ func Test_RunFullBackup_WhenContextCancelled_StopsCleanly(t *testing.T) {
setRetryDelay(5 * time.Second)
defer setRetryDelay(origRetryDelay)
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
ctx, cancel := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer cancel()
done := make(chan struct{})
@@ -360,7 +360,7 @@ func Test_RunFullBackup_WhenChainValidAndNotScheduled_NoBasebackupTriggered(t *t
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
defer cancel()
go fb.Run(ctx)
@@ -411,7 +411,7 @@ func Test_RunFullBackup_WhenStderrParsingFails_FinalizesWithErrorAndRetries(t *t
setRetryDelay(100 * time.Millisecond)
defer setRetryDelay(origRetryDelay)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
defer cancel()
go fb.Run(ctx)
@@ -458,7 +458,7 @@ func Test_RunFullBackup_WhenNextBackupTimeNull_BasebackupTriggered(t *testing.T)
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "first-run-data", validStderr())
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
defer cancel()
go fb.Run(ctx)
@@ -498,7 +498,7 @@ func Test_RunFullBackup_WhenChainValidityReturns401_NoBasebackupTriggered(t *tes
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
defer cancel()
go fb.Run(ctx)
@@ -538,7 +538,7 @@ func Test_RunFullBackup_WhenUploadSucceeds_BodyIsZstdCompressed(t *testing.T) {
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, originalContent, validStderr())
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
defer cancel()
go fb.Run(ctx)
@@ -562,6 +562,68 @@ func Test_RunFullBackup_WhenUploadSucceeds_BodyIsZstdCompressed(t *testing.T) {
assert.Equal(t, originalContent, string(decompressed))
}
// End-to-end stall test: the backup command emits an initial burst of data
// and then hangs, so the upload must fail with the idle-timeout error rather
// than blocking until the much longer overall upload deadline.
func Test_RunFullBackup_WhenUploadStalls_FailsWithIdleTimeout(t *testing.T) {
	server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case testFullStartPath:
			// Server reads body normally — it will block until connection is closed
			_, _ = io.ReadAll(r.Body)
			writeJSON(w, map[string]string{"backupId": testBackupID})
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	})

	fb := newTestFullBackuper(server.URL)
	fb.cmdBuilder = stallingCmdBuilder(t)

	// Shrink the package-level idle timeout so the test completes quickly;
	// restore the original value afterwards.
	origIdleTimeout := uploadIdleTimeout
	uploadIdleTimeout = 200 * time.Millisecond
	defer func() { uploadIdleTimeout = origIdleTimeout }()

	ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
	defer cancel()

	err := fb.executeAndUploadBasebackup(ctx)

	assert.Error(t, err)
	assert.Contains(t, err.Error(), "idle timeout", "error should mention idle timeout")
}
// stallingCmdBuilder returns a CmdBuilder that re-executes the current test
// binary as a helper process (TestHelperProcessStalling) which writes some
// data and then stalls, simulating a hung backup command.
func stallingCmdBuilder(t *testing.T) CmdBuilder {
	t.Helper()
	return func(ctx context.Context) *exec.Cmd {
		cmd := exec.CommandContext(ctx, os.Args[0],
			"-test.run=TestHelperProcessStalling",
			"--",
		)
		// The env flag lets the helper distinguish being re-executed as a
		// subprocess from a normal `go test` run (where it must no-op).
		cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS_STALLING=1")
		return cmd
	}
}
// TestHelperProcessStalling is not a real test: it only does work when
// re-executed as a subprocess by stallingCmdBuilder (guarded by the env
// variable below). It writes a burst of data to stdout and then blocks.
func TestHelperProcessStalling(t *testing.T) {
	if os.Getenv("GO_TEST_HELPER_PROCESS_STALLING") != "1" {
		return
	}
	// Write enough data to flush through the zstd encoder's internal buffer (~128KB blocks).
	// Without enough data, zstd buffers everything and the pipe never receives bytes.
	data := make([]byte, 256*1024)
	for i := range data {
		data[i] = byte(i)
	}
	_, _ = os.Stdout.Write(data)

	// Stall with stdout open — the compress goroutine blocks on its next read.
	// The parent process will kill us when the context is cancelled.
	time.Sleep(time.Hour)
	os.Exit(0)
}
func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
t.Helper()

View File

@@ -3,7 +3,6 @@ package restore
import (
"archive/tar"
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
@@ -86,7 +85,7 @@ func Test_RunRestore_WhenBasebackupAndWalSegmentsAvailable_FilesExtractedAndReco
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
err := restorer.Run(context.Background())
err := restorer.Run(t.Context())
require.NoError(t, err)
pgVersionContent, err := os.ReadFile(filepath.Join(targetDir, "PG_VERSION"))
@@ -152,7 +151,7 @@ func Test_RunRestore_WhenTargetTimeProvided_RecoveryTargetTimeWrittenToConfig(t
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "2026-02-28T14:30:00Z", "")
err := restorer.Run(context.Background())
err := restorer.Run(t.Context())
require.NoError(t, err)
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
@@ -169,7 +168,7 @@ func Test_RunRestore_WhenPgDataDirNotEmpty_ReturnsError(t *testing.T) {
restorer := newTestRestorer("http://localhost:0", targetDir, "", "", "")
err = restorer.Run(context.Background())
err = restorer.Run(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "not empty")
}
@@ -179,7 +178,7 @@ func Test_RunRestore_WhenPgDataDirDoesNotExist_ReturnsError(t *testing.T) {
restorer := newTestRestorer("http://localhost:0", nonExistentDir, "", "", "")
err := restorer.Run(context.Background())
err := restorer.Run(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "does not exist")
}
@@ -197,7 +196,7 @@ func Test_RunRestore_WhenNoBackupsAvailable_ReturnsError(t *testing.T) {
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
err := restorer.Run(context.Background())
err := restorer.Run(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "No full backups available")
}
@@ -216,7 +215,7 @@ func Test_RunRestore_WhenWalChainBroken_ReturnsError(t *testing.T) {
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
err := restorer.Run(context.Background())
err := restorer.Run(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "WAL chain broken")
assert.Contains(t, err.Error(), testWalSegment1)
@@ -282,7 +281,7 @@ func Test_DownloadWalSegment_WhenFirstAttemptFails_RetriesAndSucceeds(t *testing
retryDelayOverride = &testDelay
defer func() { retryDelayOverride = origDelay }()
err := restorer.Run(context.Background())
err := restorer.Run(t.Context())
require.NoError(t, err)
mu.Lock()
@@ -341,7 +340,7 @@ func Test_DownloadWalSegment_WhenAllAttemptsFail_ReturnsErrorWithSegmentName(t *
retryDelayOverride = &testDelay
defer func() { retryDelayOverride = origDelay }()
err := restorer.Run(context.Background())
err := restorer.Run(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), testWalSegment1)
assert.Contains(t, err.Error(), "3 attempts")
@@ -351,7 +350,7 @@ func Test_RunRestore_WhenInvalidTargetTimeFormat_ReturnsError(t *testing.T) {
targetDir := createTestTargetDir(t)
restorer := newTestRestorer("http://localhost:0", targetDir, "", "not-a-valid-time", "")
err := restorer.Run(context.Background())
err := restorer.Run(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid --target-time format")
}
@@ -384,7 +383,7 @@ func Test_RunRestore_WhenBasebackupDownloadFails_ReturnsError(t *testing.T) {
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
err := restorer.Run(context.Background())
err := restorer.Run(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "basebackup download failed")
}
@@ -423,7 +422,7 @@ func Test_RunRestore_WhenNoWalSegmentsInPlan_BasebackupRestoredSuccessfully(t *t
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
err := restorer.Run(context.Background())
err := restorer.Run(t.Context())
require.NoError(t, err)
pgVersionContent, err := os.ReadFile(filepath.Join(targetDir, "PG_VERSION"))
@@ -486,7 +485,7 @@ func Test_RunRestore_WhenMakingApiCalls_AuthTokenIncludedInRequests(t *testing.T
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
err := restorer.Run(context.Background())
err := restorer.Run(t.Context())
require.NoError(t, err)
assert.GreaterOrEqual(t, int(receivedAuthHeaders.Load()), 2)
@@ -530,7 +529,7 @@ func Test_ConfigurePostgresRecovery_WhenPgTypeHost_UsesHostAbsolutePath(t *testi
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "", "host")
err := restorer.Run(context.Background())
err := restorer.Run(t.Context())
require.NoError(t, err)
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
@@ -577,7 +576,7 @@ func Test_ConfigurePostgresRecovery_WhenPgTypeDocker_UsesContainerPath(t *testin
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "", "docker")
err := restorer.Run(context.Background())
err := restorer.Run(t.Context())
require.NoError(t, err)
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))

View File

@@ -21,7 +21,7 @@ func Test_NewLockWatcher_CapturesInode(t *testing.T) {
require.NoError(t, err)
defer ReleaseLock(lockFile)
_, cancel := context.WithCancel(context.Background())
_, cancel := context.WithCancel(t.Context())
defer cancel()
watcher, err := NewLockWatcher(lockFile, cancel, log)
@@ -37,7 +37,7 @@ func Test_LockWatcher_FileUnchanged_ContextNotCancelled(t *testing.T) {
require.NoError(t, err)
defer ReleaseLock(lockFile)
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
watcher, err := NewLockWatcher(lockFile, cancel, log)
@@ -62,7 +62,7 @@ func Test_LockWatcher_FileDeleted_CancelsContext(t *testing.T) {
require.NoError(t, err)
defer ReleaseLock(lockFile)
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
watcher, err := NewLockWatcher(lockFile, cancel, log)
@@ -88,7 +88,7 @@ func Test_LockWatcher_FileReplacedWithDifferentInode_CancelsContext(t *testing.T
require.NoError(t, err)
defer ReleaseLock(lockFile)
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
watcher, err := NewLockWatcher(lockFile, cancel, log)

View File

@@ -8,7 +8,7 @@ import (
"os"
"path/filepath"
"regexp"
"sort"
"slices"
"strings"
"time"
@@ -18,6 +18,8 @@ import (
"databasus-agent/internal/features/api"
)
var uploadIdleTimeout = 5 * time.Minute
const (
pollInterval = 10 * time.Second
uploadTimeout = 5 * time.Minute
@@ -113,7 +115,7 @@ func (s *Streamer) listSegments() ([]string, error) {
segments = append(segments, name)
}
sort.Strings(segments)
slices.Sort(segments)
return segments, nil
}
@@ -122,16 +124,27 @@ func (s *Streamer) uploadSegment(ctx context.Context, segmentName string) error
filePath := filepath.Join(s.cfg.PgWalDir, segmentName)
pr, pw := io.Pipe()
defer func() { _ = pr.Close() }()
go s.compressAndStream(pw, filePath)
uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout)
defer cancel()
uploadCtx, timeoutCancel := context.WithTimeout(ctx, uploadTimeout)
defer timeoutCancel()
idleCtx, idleCancel := context.WithCancelCause(uploadCtx)
defer idleCancel(nil)
idleReader := api.NewIdleTimeoutReader(pr, uploadIdleTimeout, idleCancel)
defer idleReader.Stop()
s.log.Info("Uploading WAL segment", "segment", segmentName)
result, err := s.apiClient.UploadWalSegment(uploadCtx, segmentName, pr)
result, err := s.apiClient.UploadWalSegment(idleCtx, segmentName, idleReader)
if err != nil {
if cause := context.Cause(idleCtx); cause != nil {
return fmt.Errorf("upload WAL segment: %w", cause)
}
return err
}

View File

@@ -2,6 +2,7 @@ package wal
import (
"context"
"crypto/rand"
"encoding/json"
"io"
"net/http"
@@ -9,6 +10,7 @@ import (
"os"
"path/filepath"
"sync"
"sync/atomic"
"testing"
"time"
@@ -42,7 +44,7 @@ func Test_UploadSegment_SingleSegment_ServerReceivesCorrectHeadersAndBody(t *tes
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
@@ -79,7 +81,7 @@ func Test_UploadSegments_MultipleSegmentsOutOfOrder_UploadedInAscendingOrder(t *
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
@@ -115,7 +117,7 @@ func Test_UploadSegments_DirectoryHasTmpFiles_TmpFilesIgnored(t *testing.T) {
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
@@ -146,7 +148,7 @@ func Test_UploadSegment_DeleteEnabled_FileRemovedAfterUpload(t *testing.T) {
apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger())
streamer := NewStreamer(cfg, apiClient, logger.GetLogger())
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
@@ -174,7 +176,7 @@ func Test_UploadSegment_DeleteDisabled_FileKeptAfterUpload(t *testing.T) {
apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger())
streamer := NewStreamer(cfg, apiClient, logger.GetLogger())
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
@@ -199,7 +201,7 @@ func Test_UploadSegment_ServerReturns500_FileKeptInQueue(t *testing.T) {
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
@@ -223,7 +225,7 @@ func Test_ProcessQueue_EmptyDirectory_NoUploads(t *testing.T) {
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
@@ -238,7 +240,7 @@ func Test_Run_ContextCancelled_StopsImmediately(t *testing.T) {
streamer := newTestStreamer(walDir, "http://localhost:0")
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
cancel()
done := make(chan struct{})
@@ -276,7 +278,7 @@ func Test_UploadSegment_ServerReturns409_FileNotDeleted(t *testing.T) {
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
@@ -287,6 +289,49 @@ func Test_UploadSegment_ServerReturns409_FileNotDeleted(t *testing.T) {
assert.NoError(t, err, "segment file should not be deleted on gap detection")
}
func Test_UploadSegment_WhenUploadStalls_FailsWithIdleTimeout(t *testing.T) {
walDir := createTestWalDir(t)
// Use incompressible random data to ensure TCP buffers fill up
segmentContent := make([]byte, 1024*1024)
_, err := rand.Read(segmentContent)
require.NoError(t, err)
writeTestSegment(t, walDir, "000000010000000100000001", segmentContent)
var requestReceived atomic.Bool
handlerDone := make(chan struct{})
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestReceived.Store(true)
// Read one byte then stall — simulates a network stall
buf := make([]byte, 1)
_, _ = r.Body.Read(buf)
<-handlerDone
}))
defer server.Close()
defer close(handlerDone)
origIdleTimeout := uploadIdleTimeout
uploadIdleTimeout = 200 * time.Millisecond
defer func() { uploadIdleTimeout = origIdleTimeout }()
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
defer cancel()
uploadErr := streamer.uploadSegment(ctx, "000000010000000100000001")
assert.Error(t, uploadErr, "upload should fail when stalled")
assert.True(t, requestReceived.Load(), "server should have received the request")
assert.Contains(t, uploadErr.Error(), "idle timeout", "error should mention idle timeout")
_, statErr := os.Stat(filepath.Join(walDir, "000000010000000100000001"))
assert.NoError(t, statErr, "segment file should remain in queue after idle timeout")
}
func newTestStreamer(walDir, serverURL string) *Streamer {
cfg := createTestConfig(walDir, serverURL)
apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger())

View File

@@ -64,16 +64,12 @@ func (w *rotatingWriter) rotate() error {
return nil
}
var (
loggerInstance *slog.Logger
once sync.Once
)
var loggerInstance *slog.Logger
var initLogger = sync.OnceFunc(initialize)
func GetLogger() *slog.Logger {
once.Do(func() {
initialize()
})
initLogger()
return loggerInstance
}

View File

@@ -67,7 +67,7 @@ func Test_Write_MultipleSmallWrites_CurrentSizeAccumulated(t *testing.T) {
rw, _, _ := setupRotatingWriter(t, 1024)
var totalWritten int64
for i := 0; i < 10; i++ {
for range 10 {
data := []byte("line\n")
n, err := rw.Write(data)
require.NoError(t, err)

View File

@@ -27,6 +27,13 @@ VALKEY_PORT=6379
VALKEY_USERNAME=
VALKEY_PASSWORD=
VALKEY_IS_SSL=false
# billing
PRICE_PER_GB_CENTS=
IS_PADDLE_SANDBOX=true
PADDLE_API_KEY=
PADDLE_WEBHOOK_SECRET=
PADDLE_PRICE_ID=
PADDLE_CLIENT_TOKEN=
# testing
# to get Google Drive env variables: add storage in UI and copy data from added storage here
TEST_GOOGLE_DRIVE_CLIENT_ID=

View File

@@ -9,6 +9,7 @@ import (
"os/exec"
"os/signal"
"path/filepath"
"runtime/debug"
"syscall"
"time"
@@ -25,6 +26,8 @@ import (
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/billing"
billing_paddle "databasus-backend/internal/features/billing/paddle"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
"databasus-backend/internal/features/encryption/secrets"
@@ -105,7 +108,9 @@ func main() {
go generateSwaggerDocs(log)
gin.SetMode(gin.ReleaseMode)
ginApp := gin.Default()
ginApp := gin.New()
ginApp.Use(gin.Logger())
ginApp.Use(ginRecoveryWithLogger(log))
// Add GZIP compression middleware
ginApp.Use(gzip.Gzip(
@@ -188,7 +193,7 @@ func startServerWithGracefulShutdown(log *slog.Logger, app *gin.Engine) {
log.Info("Shutdown signal received")
// Gracefully shutdown VictoriaLogs writer
logger.ShutdownVictoriaLogs(5 * time.Second)
logger.ShutdownVictoriaLogs()
// The context is used to inform the server it has 10 seconds to finish
// the request it is currently handling
@@ -217,6 +222,10 @@ func setUpRoutes(r *gin.Engine) {
backups_controllers.GetPostgresWalBackupController().RegisterRoutes(v1)
databases.GetDatabaseController().RegisterPublicRoutes(v1)
if config.GetEnv().IsCloud {
billing_paddle.GetPaddleBillingController().RegisterPublicRoutes(v1)
}
// Setup auth middleware
userService := users_services.GetUserService()
authMiddleware := users_middleware.AuthMiddleware(userService)
@@ -240,6 +249,7 @@ func setUpRoutes(r *gin.Engine) {
audit_logs.GetAuditLogController().RegisterRoutes(protected)
users_controllers.GetManagementController().RegisterRoutes(protected)
users_controllers.GetSettingsController().RegisterRoutes(protected)
billing.GetBillingController().RegisterRoutes(protected)
}
func setUpDependencies() {
@@ -252,6 +262,11 @@ func setUpDependencies() {
storages.SetupDependencies()
backups_config.SetupDependencies()
task_cancellation.SetupDependencies()
billing.SetupDependencies()
if config.GetEnv().IsCloud {
billing_paddle.SetupDependencies()
}
}
func runBackgroundTasks(log *slog.Logger) {
@@ -308,6 +323,12 @@ func runBackgroundTasks(log *slog.Logger) {
go runWithPanicLogging(log, "restore nodes registry background service", func() {
restoring.GetRestoreNodesRegistry().Run(ctx)
})
if config.GetEnv().IsCloud {
go runWithPanicLogging(log, "billing background service", func() {
billing.GetBillingService().Run(ctx, *log)
})
}
} else {
log.Info("Skipping primary node tasks as not primary node")
}
@@ -330,7 +351,7 @@ func runBackgroundTasks(log *slog.Logger) {
func runWithPanicLogging(log *slog.Logger, serviceName string, fn func()) {
defer func() {
if r := recover(); r != nil {
log.Error("Panic in "+serviceName, "error", r)
log.Error("Panic in "+serviceName, "error", r, "stacktrace", string(debug.Stack()))
}
}()
fn()
@@ -410,6 +431,25 @@ func enableCors(ginApp *gin.Engine) {
}
}
func ginRecoveryWithLogger(log *slog.Logger) gin.HandlerFunc {
return func(ctx *gin.Context) {
defer func() {
if r := recover(); r != nil {
log.Error("Panic recovered in HTTP handler",
"error", r,
"stacktrace", string(debug.Stack()),
"method", ctx.Request.Method,
"path", ctx.Request.URL.Path,
)
ctx.AbortWithStatus(http.StatusInternalServerError)
}
}()
ctx.Next()
}
}
func mountFrontend(ginApp *gin.Engine) {
staticDir := "./ui/build"
ginApp.NoRoute(func(c *gin.Context) {

View File

@@ -5,6 +5,7 @@ go 1.26.1
require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3
github.com/PaddleHQ/paddle-go-sdk v1.0.0
github.com/gin-contrib/cors v1.7.5
github.com/gin-contrib/gzip v1.2.3
github.com/gin-gonic/gin v1.10.0
@@ -100,6 +101,8 @@ require (
github.com/emersion/go-message v0.18.2 // indirect
github.com/emersion/go-vcard v0.0.0-20241024213814-c9703dde27ff // indirect
github.com/flynn/noise v1.1.0 // indirect
github.com/ggicci/httpin v0.19.0 // indirect
github.com/ggicci/owl v0.8.2 // indirect
github.com/go-chi/chi/v5 v5.2.3 // indirect
github.com/go-darwin/apfs v0.0.0-20211011131704-f84b94dbf348 // indirect
github.com/go-git/go-billy/v5 v5.6.2 // indirect

View File

@@ -77,6 +77,8 @@ github.com/Max-Sum/base32768 v0.0.0-20230304063302-18e6ce5945fd/go.mod h1:C8yoIf
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/PaddleHQ/paddle-go-sdk v1.0.0 h1:+EXitsPFbRcc0CpQE/MIeudxiVOR8pFe/aOWTEUHDKU=
github.com/PaddleHQ/paddle-go-sdk v1.0.0/go.mod h1:kbBBzf0BHEj38QvhtoELqlGip3alKgA/I+vl7RQzB58=
github.com/ProtonMail/bcrypt v0.0.0-20210511135022-227b4adcab57/go.mod h1:HecWFHognK8GfRDGnFQbW/LiV7A3MX3gZVs45vk5h8I=
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs69zUkSzubzjBbL+cmOXgnmt9Fyd9ug=
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo=
@@ -248,6 +250,10 @@ github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9t
github.com/geoffgarside/ber v1.1.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc=
github.com/geoffgarside/ber v1.2.0 h1:/loowoRcs/MWLYmGX9QtIAbA+V/FrnVLsMMPhwiRm64=
github.com/geoffgarside/ber v1.2.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc=
github.com/ggicci/httpin v0.19.0 h1:p0B3SWLVgg770VirYiHB14M5wdRx3zR8mCTzM/TkTQ8=
github.com/ggicci/httpin v0.19.0/go.mod h1:hzsQHcbqLabmGOycf7WNw6AAzcVbsMeoOp46bWAbIWc=
github.com/ggicci/owl v0.8.2 h1:og+lhqpzSMPDdEB+NJfzoAJARP7qCG3f8uUC3xvGukA=
github.com/ggicci/owl v0.8.2/go.mod h1:PHRD57u41vFN5UtFz2SF79yTVoM3HlWpjMiE+ZU2dj4=
github.com/gin-contrib/cors v1.7.5 h1:cXC9SmofOrRg0w9PigwGlHG3ztswH6bqq4vJVXnvYMk=
github.com/gin-contrib/cors v1.7.5/go.mod h1:4q3yi7xBEDDWKapjT2o1V7mScKDDr8k+jZ0fSquGoy0=
github.com/gin-contrib/gzip v1.2.3 h1:dAhT722RuEG330ce2agAs75z7yB+NKvX/ZM1r8w0u2U=
@@ -454,6 +460,8 @@ github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
github.com/jtolio/noiseconn v0.0.0-20231127013910-f6d9ecbf1de7 h1:JcltaO1HXM5S2KYOYcKgAV7slU0xPy1OcvrVgn98sRQ=
github.com/jtolio/noiseconn v0.0.0-20231127013910-f6d9ecbf1de7/go.mod h1:MEkhEPFwP3yudWO0lj6vfYpLIB+3eIcuIW+e0AZzUQk=
github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo=
github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA=
github.com/jzelinskie/whirlpool v0.0.0-20201016144138-0675e54bb004 h1:G+9t9cEtnC9jFiTxyptEKuNIAbiN5ZCQzX2a74lj3xg=
github.com/jzelinskie/whirlpool v0.0.0-20201016144138-0675e54bb004/go.mod h1:KmHnJWQrgEvbuy0vcvj00gtMqbvNn1L+3YUZLK/B92c=
github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU=

View File

@@ -5,6 +5,7 @@ import (
"path/filepath"
"strings"
"sync"
"time"
"github.com/ilyakaznacheev/cleanenv"
"github.com/joho/godotenv"
@@ -53,6 +54,20 @@ type EnvVariables struct {
TempFolder string
SecretKeyPath string
// Billing (always tax-exclusive)
PricePerGBCents int64 `env:"PRICE_PER_GB_CENTS"`
MinStorageGB int
MaxStorageGB int
TrialDuration time.Duration
TrialStorageGB int
GracePeriod time.Duration
// Paddle billing
IsPaddleSandbox bool `env:"IS_PADDLE_SANDBOX"`
PaddleApiKey string `env:"PADDLE_API_KEY"`
PaddleWebhookSecret string `env:"PADDLE_WEBHOOK_SECRET"`
PaddlePriceID string `env:"PADDLE_PRICE_ID"`
PaddleClientToken string `env:"PADDLE_CLIENT_TOKEN"`
TestGoogleDriveClientID string `env:"TEST_GOOGLE_DRIVE_CLIENT_ID"`
TestGoogleDriveClientSecret string `env:"TEST_GOOGLE_DRIVE_CLIENT_SECRET"`
TestGoogleDriveTokenJSON string `env:"TEST_GOOGLE_DRIVE_TOKEN_JSON"`
@@ -127,14 +142,13 @@ type EnvVariables struct {
DatabasusURL string `env:"DATABASUS_URL"`
}
var (
env EnvVariables
once sync.Once
)
var env EnvVariables
func GetEnv() EnvVariables {
once.Do(loadEnvVariables)
return env
var initEnv = sync.OnceFunc(loadEnvVariables)
func GetEnv() *EnvVariables {
initEnv()
return &env
}
func loadEnvVariables() {
@@ -363,5 +377,39 @@ func loadEnvVariables() {
}
// Billing
if env.IsCloud {
if env.PricePerGBCents == 0 {
log.Error("PRICE_PER_GB_CENTS is empty or zero")
os.Exit(1)
}
if env.PaddleApiKey == "" {
log.Error("PADDLE_API_KEY is empty")
os.Exit(1)
}
if env.PaddleWebhookSecret == "" {
log.Error("PADDLE_WEBHOOK_SECRET is empty")
os.Exit(1)
}
if env.PaddlePriceID == "" {
log.Error("PADDLE_PRICE_ID is empty")
os.Exit(1)
}
if env.PaddleClientToken == "" {
log.Error("PADDLE_CLIENT_TOKEN is empty")
os.Exit(1)
}
}
env.MinStorageGB = 20
env.MaxStorageGB = 10_000
env.TrialDuration = 24 * time.Hour
env.TrialStorageGB = 20
env.GracePeriod = 30 * 24 * time.Hour
log.Info("Environment variables loaded successfully!")
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
)
@@ -13,39 +12,32 @@ type AuditLogBackgroundService struct {
auditLogService *AuditLogService
logger *slog.Logger
runOnce sync.Once
hasRun atomic.Bool
hasRun atomic.Bool
}
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
wasAlreadyRun := s.hasRun.Load()
if s.hasRun.Swap(true) {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
s.runOnce.Do(func() {
s.hasRun.Store(true)
s.logger.Info("Starting audit log cleanup background service")
s.logger.Info("Starting audit log cleanup background service")
if ctx.Err() != nil {
return
}
if ctx.Err() != nil {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
}
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
case <-ticker.C:
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}

View File

@@ -102,7 +102,7 @@ func Test_CleanOldAuditLogs_DeletesMultipleOldLogs(t *testing.T) {
// Create many old logs with specific UUIDs to track them
testLogIDs := make([]uuid.UUID, 5)
for i := 0; i < 5; i++ {
for i := range 5 {
testLogIDs[i] = uuid.New()
daysAgo := 400 + (i * 10)
log := &AuditLog{

View File

@@ -2,7 +2,6 @@ package audit_logs
import (
"sync"
"sync/atomic"
users_services "databasus-backend/internal/features/users/services"
"databasus-backend/internal/util/logger"
@@ -23,8 +22,6 @@ var auditLogController = &AuditLogController{
var auditLogBackgroundService = &AuditLogBackgroundService{
auditLogService: auditLogService,
logger: logger.GetLogger(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
func GetAuditLogService() *AuditLogService {
@@ -39,23 +36,8 @@ func GetAuditLogBackgroundService() *AuditLogBackgroundService {
return auditLogBackgroundService
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
users_services.GetUserService().SetAuditLogWriter(auditLogService)
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}
var SetupDependencies = sync.OnceFunc(func() {
users_services.GetUserService().SetAuditLogWriter(auditLogService)
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
})

View File

@@ -9,7 +9,6 @@ import (
"log/slog"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
@@ -46,80 +45,73 @@ type BackuperNode struct {
lastHeartbeat time.Time
runOnce sync.Once
hasRun atomic.Bool
hasRun atomic.Bool
}
func (n *BackuperNode) Run(ctx context.Context) {
wasAlreadyRun := n.hasRun.Load()
if n.hasRun.Swap(true) {
panic(fmt.Sprintf("%T.Run() called multiple times", n))
}
n.runOnce.Do(func() {
n.hasRun.Store(true)
n.lastHeartbeat = time.Now().UTC()
n.lastHeartbeat = time.Now().UTC()
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
backupNode := BackupNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: time.Now().UTC(),
}
backupNode := BackupNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: time.Now().UTC(),
}
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
go func() {
n.MakeBackup(backupID, isCallNotifier)
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
err,
"backupID",
backupID,
)
}
}()
}
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
if err != nil {
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
panic(err)
}
defer func() {
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
go func() {
n.MakeBackup(backupID, isCallNotifier)
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
err,
"backupID",
backupID,
)
}
}()
}
ticker := time.NewTicker(heartbeatTickerInterval)
defer ticker.Stop()
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
}
return
case <-ticker.C:
n.sendHeartbeat(&backupNode)
}
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
if err != nil {
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
panic(err)
}
defer func() {
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
}
})
}()
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", n))
ticker := time.NewTicker(heartbeatTickerInterval)
defer ticker.Stop()
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
}
return
case <-ticker.C:
n.sendHeartbeat(&backupNode)
}
}
}
@@ -171,26 +163,6 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
backup.BackupSizeMb = completedMBs
backup.BackupDurationMs = time.Since(start).Milliseconds()
// Check size limit (0 = unlimited)
if backupConfig.MaxBackupSizeMB > 0 &&
completedMBs > float64(backupConfig.MaxBackupSizeMB) {
errMsg := fmt.Sprintf(
"backup size (%.2f MB) exceeded maximum allowed size (%d MB)",
completedMBs,
backupConfig.MaxBackupSizeMB,
)
backup.Status = backups_core.BackupStatusFailed
backup.IsSkipRetry = true
backup.FailMessage = &errMsg
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to save backup with size exceeded error", "error", err)
}
cancel() // Cancel the backup context
return
}
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to update backup progress", "error", err)
}

View File

@@ -153,121 +153,3 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
assert.Equal(t, notifier.ID, capturedNotifier.ID)
})
}
func Test_BackupSizeLimits(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
t.Run("UnlimitedSize_MaxBackupSizeMBIsZero_BackupCompletes", func(t *testing.T) {
// Enable backups with unlimited size (0)
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 0 // unlimited
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateLargeBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup completed successfully even with large size
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
assert.Equal(t, float64(10000), updatedBackup.BackupSizeMb)
assert.Nil(t, updatedBackup.FailMessage)
})
t.Run("SizeExceeded_BackupFailedWithIsSkipRetry", func(t *testing.T) {
// Enable backups with 5 MB limit
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 5
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateProgressiveBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup was marked as failed with IsSkipRetry=true
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusFailed, updatedBackup.Status)
assert.True(t, updatedBackup.IsSkipRetry)
assert.NotNil(t, updatedBackup.FailMessage)
assert.Contains(t, *updatedBackup.FailMessage, "exceeded maximum allowed size")
assert.Contains(t, *updatedBackup.FailMessage, "10.00 MB")
assert.Contains(t, *updatedBackup.FailMessage, "5 MB")
assert.Greater(t, updatedBackup.BackupSizeMb, float64(5))
})
t.Run("SizeWithinLimit_BackupCompletes", func(t *testing.T) {
// Enable backups with 100 MB limit
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 100
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateMediumBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup completed successfully
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
assert.Equal(t, float64(50), updatedBackup.BackupSizeMb)
assert.Nil(t, updatedBackup.FailMessage)
})
}

View File

@@ -4,12 +4,12 @@ import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
@@ -26,49 +26,47 @@ type BackupCleaner struct {
backupRepository *backups_core.BackupRepository
storageService *storages.StorageService
backupConfigService *backups_config.BackupConfigService
billingService BillingService
fieldEncryptor util_encryption.FieldEncryptor
logger *slog.Logger
backupRemoveListeners []backups_core.BackupRemoveListener
runOnce sync.Once
hasRun atomic.Bool
hasRun atomic.Bool
}
func (c *BackupCleaner) Run(ctx context.Context) {
wasAlreadyRun := c.hasRun.Load()
if c.hasRun.Swap(true) {
panic(fmt.Sprintf("%T.Run() called multiple times", c))
}
c.runOnce.Do(func() {
c.hasRun.Store(true)
if ctx.Err() != nil {
return
}
if ctx.Err() != nil {
retentionLog := c.logger.With("task_name", "clean_by_retention_policy")
exceededLog := c.logger.With("task_name", "clean_exceeded_storage_backups")
staleLog := c.logger.With("task_name", "clean_stale_basebackups")
ticker := time.NewTicker(cleanerTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
}
case <-ticker.C:
if err := c.cleanByRetentionPolicy(retentionLog); err != nil {
retentionLog.Error("failed to clean backups by retention policy", "error", err)
}
ticker := time.NewTicker(cleanerTickerInterval)
defer ticker.Stop()
if err := c.cleanExceededStorageBackups(exceededLog); err != nil {
exceededLog.Error("failed to clean exceeded backups", "error", err)
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := c.cleanByRetentionPolicy(); err != nil {
c.logger.Error("Failed to clean backups by retention policy", "error", err)
}
if err := c.cleanExceededBackups(); err != nil {
c.logger.Error("Failed to clean exceeded backups", "error", err)
}
if err := c.cleanStaleUploadedBasebackups(); err != nil {
c.logger.Error("Failed to clean stale uploaded basebackups", "error", err)
}
if err := c.cleanStaleUploadedBasebackups(staleLog); err != nil {
staleLog.Error("failed to clean stale uploaded basebackups", "error", err)
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", c))
}
}
@@ -104,7 +102,7 @@ func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemo
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
}
func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
func (c *BackupCleaner) cleanStaleUploadedBasebackups(logger *slog.Logger) error {
staleBackups, err := c.backupRepository.FindStaleUploadedBasebackups(
time.Now().UTC().Add(-10 * time.Minute),
)
@@ -113,31 +111,30 @@ func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
}
for _, backup := range staleBackups {
backupLog := logger.With("database_id", backup.DatabaseID, "backup_id", backup.ID)
staleStorage, storageErr := c.storageService.GetStorageByID(backup.StorageID)
if storageErr != nil {
c.logger.Error(
"Failed to get storage for stale basebackup cleanup",
"backupId", backup.ID,
"storageId", backup.StorageID,
backupLog.Error(
"failed to get storage for stale basebackup cleanup",
"storage_id", backup.StorageID,
"error", storageErr,
)
} else {
if err := staleStorage.DeleteFile(c.fieldEncryptor, backup.FileName); err != nil {
c.logger.Error(
"Failed to delete stale basebackup file",
"backupId", backup.ID,
"fileName", backup.FileName,
"error", err,
backupLog.Error(
fmt.Sprintf("failed to delete stale basebackup file: %s", backup.FileName),
"error",
err,
)
}
metadataFileName := backup.FileName + ".metadata"
if err := staleStorage.DeleteFile(c.fieldEncryptor, metadataFileName); err != nil {
c.logger.Error(
"Failed to delete stale basebackup metadata file",
"backupId", backup.ID,
"fileName", metadataFileName,
"error", err,
backupLog.Error(
fmt.Sprintf("failed to delete stale basebackup metadata file: %s", metadataFileName),
"error",
err,
)
}
}
@@ -147,77 +144,67 @@ func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
backup.FailMessage = &failMsg
if err := c.backupRepository.Save(backup); err != nil {
c.logger.Error(
"Failed to mark stale uploaded basebackup as failed",
"backupId", backup.ID,
"error", err,
)
backupLog.Error("failed to mark stale uploaded basebackup as failed", "error", err)
continue
}
c.logger.Info(
"Marked stale uploaded basebackup as failed and cleaned storage",
"backupId", backup.ID,
"databaseId", backup.DatabaseID,
)
backupLog.Info("marked stale uploaded basebackup as failed and cleaned storage")
}
return nil
}
func (c *BackupCleaner) cleanByRetentionPolicy() error {
func (c *BackupCleaner) cleanByRetentionPolicy(logger *slog.Logger) error {
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
dbLog := logger.With("database_id", backupConfig.DatabaseID, "policy", backupConfig.RetentionPolicyType)
var cleanErr error
switch backupConfig.RetentionPolicyType {
case backups_config.RetentionPolicyTypeCount:
cleanErr = c.cleanByCount(backupConfig)
cleanErr = c.cleanByCount(dbLog, backupConfig)
case backups_config.RetentionPolicyTypeGFS:
cleanErr = c.cleanByGFS(backupConfig)
cleanErr = c.cleanByGFS(dbLog, backupConfig)
default:
cleanErr = c.cleanByTimePeriod(backupConfig)
cleanErr = c.cleanByTimePeriod(dbLog, backupConfig)
}
if cleanErr != nil {
c.logger.Error(
"Failed to clean backups by retention policy",
"databaseId", backupConfig.DatabaseID,
"policy", backupConfig.RetentionPolicyType,
"error", cleanErr,
)
dbLog.Error("failed to clean backups by retention policy", "error", cleanErr)
}
}
return nil
}
func (c *BackupCleaner) cleanExceededBackups() error {
func (c *BackupCleaner) cleanExceededStorageBackups(logger *slog.Logger) error {
if !config.GetEnv().IsCloud {
return nil
}
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
if backupConfig.MaxBackupsTotalSizeMB <= 0 {
dbLog := logger.With("database_id", backupConfig.DatabaseID)
subscription, subErr := c.billingService.GetSubscription(dbLog, backupConfig.DatabaseID)
if subErr != nil {
dbLog.Error("failed to get subscription for exceeded backups check", "error", subErr)
continue
}
if err := c.cleanExceededBackupsForDatabase(
backupConfig.DatabaseID,
backupConfig.MaxBackupsTotalSizeMB,
); err != nil {
c.logger.Error(
"Failed to clean exceeded backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
storageLimitMB := int64(subscription.GetBackupsStorageGB()) * 1024
if err := c.cleanExceededBackupsForDatabase(dbLog, backupConfig.DatabaseID, storageLimitMB); err != nil {
dbLog.Error("failed to clean exceeded backups for database", "error", err)
continue
}
}
@@ -225,7 +212,7 @@ func (c *BackupCleaner) cleanExceededBackups() error {
return nil
}
func (c *BackupCleaner) cleanByTimePeriod(backupConfig *backups_config.BackupConfig) error {
func (c *BackupCleaner) cleanByTimePeriod(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
if backupConfig.RetentionTimePeriod == "" {
return nil
}
@@ -255,21 +242,17 @@ func (c *BackupCleaner) cleanByTimePeriod(backupConfig *backups_config.BackupCon
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
logger.Error("failed to delete old backup", "backup_id", backup.ID, "error", err)
continue
}
c.logger.Info(
"Deleted old backup",
"backupId", backup.ID,
"databaseId", backupConfig.DatabaseID,
)
logger.Info("deleted old backup", "backup_id", backup.ID)
}
return nil
}
func (c *BackupCleaner) cleanByCount(backupConfig *backups_config.BackupConfig) error {
func (c *BackupCleaner) cleanByCount(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
if backupConfig.RetentionCount <= 0 {
return nil
}
@@ -298,28 +281,20 @@ func (c *BackupCleaner) cleanByCount(backupConfig *backups_config.BackupConfig)
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete backup by count policy",
"backupId",
backup.ID,
"error",
err,
)
logger.Error("failed to delete backup by count policy", "backup_id", backup.ID, "error", err)
continue
}
c.logger.Info(
"Deleted backup by count policy",
"backupId", backup.ID,
"databaseId", backupConfig.DatabaseID,
"retentionCount", backupConfig.RetentionCount,
logger.Info(
fmt.Sprintf("deleted backup by count policy: retention count is %d", backupConfig.RetentionCount),
"backup_id", backup.ID,
)
}
return nil
}
func (c *BackupCleaner) cleanByGFS(backupConfig *backups_config.BackupConfig) error {
func (c *BackupCleaner) cleanByGFS(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
if backupConfig.RetentionGfsHours <= 0 && backupConfig.RetentionGfsDays <= 0 &&
backupConfig.RetentionGfsWeeks <= 0 && backupConfig.RetentionGfsMonths <= 0 &&
backupConfig.RetentionGfsYears <= 0 {
@@ -357,29 +332,20 @@ func (c *BackupCleaner) cleanByGFS(backupConfig *backups_config.BackupConfig) er
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete backup by GFS policy",
"backupId",
backup.ID,
"error",
err,
)
logger.Error("failed to delete backup by GFS policy", "backup_id", backup.ID, "error", err)
continue
}
c.logger.Info(
"Deleted backup by GFS policy",
"backupId", backup.ID,
"databaseId", backupConfig.DatabaseID,
)
logger.Info("deleted backup by GFS policy", "backup_id", backup.ID)
}
return nil
}
func (c *BackupCleaner) cleanExceededBackupsForDatabase(
logger *slog.Logger,
databaseID uuid.UUID,
limitperDbMB int64,
limitPerDbMB int64,
) error {
for {
backupsTotalSizeMB, err := c.backupRepository.GetTotalSizeByDatabase(databaseID)
@@ -387,7 +353,7 @@ func (c *BackupCleaner) cleanExceededBackupsForDatabase(
return err
}
if backupsTotalSizeMB <= float64(limitperDbMB) {
if backupsTotalSizeMB <= float64(limitPerDbMB) {
break
}
@@ -400,59 +366,27 @@ func (c *BackupCleaner) cleanExceededBackupsForDatabase(
}
if len(oldestBackups) == 0 {
c.logger.Warn(
"No backups to delete but still over limit",
"databaseId",
databaseID,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
)
logger.Warn(fmt.Sprintf(
"no backups to delete but still over limit: total size is %.1f MB, limit is %d MB",
backupsTotalSizeMB, limitPerDbMB,
))
break
}
backup := oldestBackups[0]
if isRecentBackup(backup) {
c.logger.Warn(
"Oldest backup is too recent to delete, stopping size cleanup",
"databaseId",
databaseID,
"backupId",
backup.ID,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
)
break
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete exceeded backup",
"backupId",
backup.ID,
"databaseId",
databaseID,
"error",
err,
)
logger.Error("failed to delete exceeded backup", "backup_id", backup.ID, "error", err)
return err
}
c.logger.Info(
"Deleted exceeded backup",
"backupId",
backup.ID,
"databaseId",
databaseID,
"backupSizeMB",
backup.BackupSizeMb,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
logger.Info(
fmt.Sprintf("deleted exceeded backup: backup size is %.1f MB, total size is %.1f MB, limit is %d MB",
backup.BackupSizeMb, backupsTotalSizeMB, limitPerDbMB),
"backup_id", backup.ID,
)
}

View File

@@ -33,7 +33,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
// backupsEveryDay returns n backups, newest-first, each 1 day apart.
backupsEveryDay := func(n int) []*backups_core.Backup {
bs := make([]*backups_core.Backup, n)
for i := 0; i < n; i++ {
for i := range n {
bs[i] = newBackup(ref.Add(-time.Duration(i) * day))
}
return bs
@@ -42,7 +42,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
// backupsEveryWeek returns n backups, newest-first, each 7 days apart.
backupsEveryWeek := func(n int) []*backups_core.Backup {
bs := make([]*backups_core.Backup, n)
for i := 0; i < n; i++ {
for i := range n {
bs[i] = newBackup(ref.Add(-time.Duration(i) * week))
}
return bs
@@ -53,7 +53,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
// backupsEveryHour returns n backups, newest-first, each 1 hour apart.
backupsEveryHour := func(n int) []*backups_core.Backup {
bs := make([]*backups_core.Backup, n)
for i := 0; i < n; i++ {
for i := range n {
bs[i] = newBackup(ref.Add(-time.Duration(i) * hour))
}
return bs
@@ -62,7 +62,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
// backupsEveryMonth returns n backups, newest-first, each ~1 month apart.
backupsEveryMonth := func(n int) []*backups_core.Backup {
bs := make([]*backups_core.Backup, n)
for i := 0; i < n; i++ {
for i := range n {
bs[i] = newBackup(ref.AddDate(0, -i, 0))
}
return bs
@@ -71,7 +71,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
// backupsEveryYear returns n backups, newest-first, each 1 year apart.
backupsEveryYear := func(n int) []*backups_core.Backup {
bs := make([]*backups_core.Backup, n)
for i := 0; i < n; i++ {
for i := range n {
bs[i] = newBackup(ref.AddDate(-i, 0, 0))
}
return bs
@@ -410,7 +410,7 @@ func Test_CleanByGFS_KeepsCorrectBackupsPerSlot(t *testing.T) {
// Create 5 backups on 5 different days; only the 3 newest days should be kept
var backupIDs []uuid.UUID
for i := 0; i < 5; i++ {
for i := range 5 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -425,7 +425,7 @@ func Test_CleanByGFS_KeepsCorrectBackupsPerSlot(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -486,7 +486,7 @@ func Test_CleanByGFS_WithWeeklyAndMonthlySlots_KeepsWiderSpread(t *testing.T) {
// Create one backup per week for 6 weeks (each on Monday of that week)
// GFS should keep: 2 daily (most recent 2 unique days) + 2 weekly + 1 monthly = up to 5 unique
var createdIDs []uuid.UUID
for i := 0; i < 6; i++ {
for i := range 6 {
weekOffset := time.Duration(5-i) * 7 * 24 * time.Hour
backup := &backups_core.Backup{
ID: uuid.New(),
@@ -502,7 +502,7 @@ func Test_CleanByGFS_WithWeeklyAndMonthlySlots_KeepsWiderSpread(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -561,7 +561,7 @@ func Test_CleanByGFS_WithHourlySlots_KeepsCorrectBackups(t *testing.T) {
// Create 5 backups spaced 1 hour apart; only the 3 newest hours should be kept
var backupIDs []uuid.UUID
for i := 0; i < 5; i++ {
for i := range 5 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -576,7 +576,7 @@ func Test_CleanByGFS_WithHourlySlots_KeepsCorrectBackups(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -677,7 +677,7 @@ func Test_CleanByGFS_SkipsRecentBackup_WhenNotInKeepSet(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -759,7 +759,7 @@ func Test_CleanByGFS_With20DailyBackups_KeepsOnlyExpectedCount(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -824,8 +824,8 @@ func Test_CleanByGFS_WithMultipleBackupsPerDay_KeepsOnlyOnePerDailySlot(t *testi
// Create 3 backups per day for 10 days = 30 total, all beyond grace period.
// Each day gets backups at base+0h, base+6h, base+12h.
for day := 0; day < 10; day++ {
for sub := 0; sub < 3; sub++ {
for day := range 10 {
for sub := range 3 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -844,7 +844,7 @@ func Test_CleanByGFS_WithMultipleBackupsPerDay_KeepsOnlyOnePerDailySlot(t *testi
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -915,7 +915,7 @@ func Test_CleanByGFS_With24HourlySlotsAnd23DailyBackups_DeletesExcessBackups(t *
now := time.Now().UTC()
for i := 0; i < 23; i++ {
for i := range 23 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -929,7 +929,7 @@ func Test_CleanByGFS_With24HourlySlotsAnd23DailyBackups_DeletesExcessBackups(t *
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -985,7 +985,7 @@ func Test_CleanByGFS_WithDisabledHourlySlotsAnd23DailyBackups_DeletesExcessBacku
now := time.Now().UTC()
for i := 0; i < 23; i++ {
for i := range 23 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -999,7 +999,7 @@ func Test_CleanByGFS_WithDisabledHourlySlotsAnd23DailyBackups_DeletesExcessBacku
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -1055,7 +1055,7 @@ func Test_CleanByGFS_WithDailySlotsAndWeeklyBackups_DeletesExcessBackups(t *test
// Create 10 weekly backups (1 per week, all >2h old past grace period).
// With 7d/4w config, correct behavior: ~8 kept (4 weekly + overlap with daily for recent ones).
// Daily slots should NOT absorb weekly backups that are older than 7 days.
for i := 0; i < 10; i++ {
for i := range 10 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -1069,7 +1069,7 @@ func Test_CleanByGFS_WithDailySlotsAndWeeklyBackups_DeletesExcessBackups(t *test
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -1138,7 +1138,7 @@ func Test_CleanByGFS_WithWeeklySlotsAndMonthlyBackups_DeletesExcessBackups(t *te
// With 52w/3m config, correct behavior: 3 kept (3 monthly slots; weekly should only
// cover recent 52 weeks but not artificially retain old monthly backups).
// Bug: all 8 kept because each monthly backup fills a unique weekly slot.
for i := 0; i < 8; i++ {
for i := range 8 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -1152,7 +1152,7 @@ func Test_CleanByGFS_WithWeeklySlotsAndMonthlyBackups_DeletesExcessBackups(t *te
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)

View File

@@ -1,14 +1,17 @@
package backuping
import (
"log/slog"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
billing_models "databasus-backend/internal/features/billing/models"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
@@ -17,6 +20,7 @@ import (
users_testing "databasus-backend/internal/features/users/testing"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/storage"
"databasus-backend/internal/util/logger"
"databasus-backend/internal/util/period"
)
@@ -51,6 +55,7 @@ func Test_CleanOldBackups_DeletesBackupsOlderThanRetentionTimePeriod(t *testing.
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -89,7 +94,7 @@ func Test_CleanOldBackups_DeletesBackupsOlderThanRetentionTimePeriod(t *testing.
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -129,6 +134,7 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -145,7 +151,7 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -154,7 +160,8 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
assert.Equal(t, oldBackup.ID, remainingBackups[0].ID)
}
func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
func Test_CleanExceededBackups_WhenUnderStorageLimit_NoBackupsDeleted(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -178,33 +185,36 @@ func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 100,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
for i := 0; i < 3; i++ {
for i := range 3 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 16.67,
BackupSizeMb: 100,
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
}
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -212,7 +222,8 @@ func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
assert.Equal(t, 3, len(remainingBackups))
}
func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T) {
func Test_CleanExceededBackups_WhenOverStorageLimit_DeletesOldestBackups(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -236,27 +247,29 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 30,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// 5 backups at 300 MB each = 1500 MB total, limit = 1 GB (1024 MB)
// Expect 2 oldest deleted, 3 remain (900 MB < 1024 MB)
now := time.Now().UTC()
var backupIDs []uuid.UUID
for i := 0; i < 5; i++ {
for i := range 5 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
BackupSizeMb: 300,
CreatedAt: now.Add(-time.Duration(4-i) * time.Hour),
}
err = backupRepository.Save(backup)
@@ -264,8 +277,11 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
backupIDs = append(backupIDs, backup.ID)
}
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -284,6 +300,7 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
}
func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -307,28 +324,29 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 50,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
now := time.Now().UTC()
// 3 completed at 500 MB each = 1500 MB, limit = 1 GB (1024 MB)
completedBackups := make([]*backups_core.Backup, 3)
for i := 0; i < 3; i++ {
for i := range 3 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 30,
BackupSizeMb: 500,
CreatedAt: now.Add(-time.Duration(3-i) * time.Hour),
}
err = backupRepository.Save(backup)
@@ -347,8 +365,11 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
err = backupRepository.Save(inProgressBackup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -365,7 +386,8 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
assert.True(t, inProgressFound, "In-progress backup should not be deleted")
}
func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
func Test_CleanExceededBackups_WithZeroStorageLimit_RemovesAllBackups(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -389,38 +411,42 @@ func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 0,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
for i := 0; i < 10; i++ {
for i := range 10 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 100,
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
CreatedAt: time.Now().UTC().Add(-time.Duration(i+2) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
}
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
// StorageGB=0 means no storage allowed — all backups should be removed
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 0, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 10, len(remainingBackups))
assert.Equal(t, 0, len(remainingBackups))
}
func Test_GetTotalSizeByDatabase_CalculatesCorrectly(t *testing.T) {
@@ -522,13 +548,14 @@ func Test_CleanByCount_KeepsNewestNBackups_DeletesOlder(t *testing.T) {
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
now := time.Now().UTC()
var backupIDs []uuid.UUID
for i := 0; i < 5; i++ {
for i := range 5 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -545,7 +572,7 @@ func Test_CleanByCount_KeepsNewestNBackups_DeletesOlder(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -594,11 +621,12 @@ func Test_CleanByCount_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
for i := 0; i < 5; i++ {
for i := range 5 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -612,7 +640,7 @@ func Test_CleanByCount_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -651,13 +679,14 @@ func Test_CleanByCount_DoesNotDeleteInProgressBackups(t *testing.T) {
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
now := time.Now().UTC()
for i := 0; i < 3; i++ {
for i := range 3 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -682,7 +711,7 @@ func Test_CleanByCount_DoesNotDeleteInProgressBackups(t *testing.T) {
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -776,6 +805,7 @@ func Test_CleanByTimePeriod_SkipsRecentBackup_EvenIfOlderThanRetention(t *testin
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -805,7 +835,7 @@ func Test_CleanByTimePeriod_SkipsRecentBackup_EvenIfOlderThanRetention(t *testin
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -847,6 +877,7 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -893,7 +924,7 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -914,7 +945,8 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
assert.True(t, remainingIDs[newestBackup.ID], "Newest backup should be preserved")
}
func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testing.T) {
func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverStorageLimit(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -937,18 +969,18 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
interval := createTestInterval()
// Total size limit is 10 MB. We have two backups of 8 MB each (16 MB total).
// Total size limit = 1 GB (1024 MB). Two backups of 600 MB each (1200 MB total).
// The oldest backup was created 30 minutes ago — within the grace period.
// The cleaner must stop and leave both backups intact.
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 10,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -960,7 +992,7 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 8,
BackupSizeMb: 600,
CreatedAt: now.Add(-30 * time.Minute),
}
newerRecentBackup := &backups_core.Backup{
@@ -968,7 +1000,7 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 8,
BackupSizeMb: 600,
CreatedAt: now.Add(-10 * time.Minute),
}
@@ -977,8 +1009,11 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
err = backupRepository.Save(newerRecentBackup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -991,6 +1026,82 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
)
}
func Test_CleanExceededStorageBackups_WhenNonCloud_SkipsCleanup(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// 5 backups at 500 MB each = 2500 MB, would exceed 1 GB limit in cloud mode
now := time.Now().UTC()
for i := range 5 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 500,
CreatedAt: now.Add(-time.Duration(i+2) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
}
// IsCloud is false by default — cleaner should skip entirely
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 5, len(remainingBackups), "All backups must remain in non-cloud mode")
}
type mockBillingService struct {
subscription *billing_models.Subscription
err error
}
func (m *mockBillingService) GetSubscription(
logger *slog.Logger,
databaseID uuid.UUID,
) (*billing_models.Subscription, error) {
return m.subscription, m.err
}
// Mock listener for testing
type mockBackupRemoveListener struct {
onBeforeBackupRemove func(*backups_core.Backup) error
@@ -1041,7 +1152,7 @@ func Test_CleanStaleUploadedBasebackups_MarksAsFailed(t *testing.T) {
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups()
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
assert.NoError(t, err)
updated, err := backupRepository.FindByID(staleBackup.ID)
@@ -1088,7 +1199,7 @@ func Test_CleanStaleUploadedBasebackups_SkipsRecentUploads(t *testing.T) {
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups()
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
assert.NoError(t, err)
updated, err := backupRepository.FindByID(recentBackup.ID)
@@ -1131,7 +1242,7 @@ func Test_CleanStaleUploadedBasebackups_SkipsActiveStreaming(t *testing.T) {
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups()
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
assert.NoError(t, err)
updated, err := backupRepository.FindByID(activeBackup.ID)
@@ -1179,7 +1290,7 @@ func Test_CleanStaleUploadedBasebackups_CleansStorageFiles(t *testing.T) {
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups()
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
assert.NoError(t, err)
updated, err := backupRepository.FindByID(staleBackup.ID)
@@ -1189,6 +1300,18 @@ func Test_CleanStaleUploadedBasebackups_CleansStorageFiles(t *testing.T) {
assert.Contains(t, *updated.FailMessage, "finalization timed out")
}
func enableCloud(t *testing.T) {
t.Helper()
config.GetEnv().IsCloud = true
t.Cleanup(func() {
config.GetEnv().IsCloud = false
})
}
func testLogger() *slog.Logger {
return logger.GetLogger().With("task_name", "test")
}
func createTestInterval() *intervals.Interval {
timeOfDay := "04:00"
interval := &intervals.Interval{

View File

@@ -1,7 +1,6 @@
package backuping
import (
"sync"
"sync/atomic"
"time"
@@ -10,6 +9,7 @@ import (
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/billing"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
@@ -28,10 +28,10 @@ var backupCleaner = &BackupCleaner{
backupRepository,
storages.GetStorageService(),
backups_config.GetBackupConfigService(),
billing.GetBillingService(),
encryption.GetFieldEncryptor(),
logger.GetLogger(),
[]backups_core.BackupRemoveListener{},
sync.Once{},
atomic.Bool{},
}
@@ -41,7 +41,6 @@ var backupNodesRegistry = &BackupNodesRegistry{
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
sync.Once{},
atomic.Bool{},
}
@@ -63,7 +62,6 @@ var backuperNode = &BackuperNode{
usecases.GetCreateBackupUsecase(),
getNodeID(),
time.Time{},
sync.Once{},
atomic.Bool{},
}
@@ -73,11 +71,11 @@ var backupsScheduler = &BackupsScheduler{
taskCancelManager,
backupNodesRegistry,
databases.GetDatabaseService(),
billing.GetBillingService(),
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]BackupToNodeRelation),
backuperNode,
sync.Once{},
atomic.Bool{},
}

View File

@@ -0,0 +1,13 @@
package backuping
import (
"log/slog"
"github.com/google/uuid"
billing_models "databasus-backend/internal/features/billing/models"
)
type BillingService interface {
GetSubscription(logger *slog.Logger, databaseID uuid.UUID) (*billing_models.Subscription, error)
}

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"log/slog"
"strings"
"sync"
"sync/atomic"
"time"
@@ -50,36 +49,30 @@ type BackupNodesRegistry struct {
pubsubBackups *cache_utils.PubSubManager
pubsubCompletions *cache_utils.PubSubManager
runOnce sync.Once
hasRun atomic.Bool
hasRun atomic.Bool
}
func (r *BackupNodesRegistry) Run(ctx context.Context) {
wasAlreadyRun := r.hasRun.Load()
if r.hasRun.Swap(true) {
panic(fmt.Sprintf("%T.Run() called multiple times", r))
}
r.runOnce.Do(func() {
r.hasRun.Store(true)
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
}
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
}
ticker := time.NewTicker(cleanupTickerInterval)
defer ticker.Stop()
ticker := time.NewTicker(cleanupTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes", "error", err)
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes", "error", err)
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", r))
}
}

View File

@@ -4,8 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
@@ -322,7 +320,7 @@ func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) {
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
ctx, cancel := context.WithTimeout(t.Context(), registry.timeout)
defer cancel()
invalidKey := nodeInfoKeyPrefix + uuid.New().String() + nodeInfoKeySuffix
@@ -331,7 +329,7 @@ func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) {
registry.client.B().Set().Key(invalidKey).Value("invalid json data").Build(),
)
defer func() {
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), registry.timeout)
cleanupCtx, cleanupCancel := context.WithTimeout(t.Context(), registry.timeout)
defer cleanupCancel()
registry.client.Do(cleanupCtx, registry.client.B().Del().Key(invalidKey).Build())
}()
@@ -401,7 +399,7 @@ func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) {
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
assert.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
ctx, cancel := context.WithTimeout(t.Context(), registry.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
@@ -419,7 +417,7 @@ func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) {
modifiedData, err := json.Marshal(node)
assert.NoError(t, err)
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
setCtx, setCancel := context.WithTimeout(t.Context(), registry.timeout)
defer setCancel()
setResult := registry.client.Do(
setCtx,
@@ -464,7 +462,7 @@ func Test_GetBackupNodesStats_ExcludesStaleNodesFromCache(t *testing.T) {
err = registry.IncrementBackupsInProgress(node3.ID)
assert.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
ctx, cancel := context.WithTimeout(t.Context(), registry.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
@@ -482,7 +480,7 @@ func Test_GetBackupNodesStats_ExcludesStaleNodesFromCache(t *testing.T) {
modifiedData, err := json.Marshal(node)
assert.NoError(t, err)
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
setCtx, setCancel := context.WithTimeout(t.Context(), registry.timeout)
defer setCancel()
setResult := registry.client.Do(
setCtx,
@@ -524,7 +522,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
err = registry.IncrementBackupsInProgress(node2.ID)
assert.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
ctx, cancel := context.WithTimeout(t.Context(), registry.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
@@ -542,7 +540,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
modifiedData, err := json.Marshal(node)
assert.NoError(t, err)
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
setCtx, setCancel := context.WithTimeout(t.Context(), registry.timeout)
defer setCancel()
setResult := registry.client.Do(
setCtx,
@@ -553,7 +551,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
err = registry.cleanupDeadNodes()
assert.NoError(t, err)
checkCtx, checkCancel := context.WithTimeout(context.Background(), registry.timeout)
checkCtx, checkCancel := context.WithTimeout(t.Context(), registry.timeout)
defer checkCancel()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
@@ -566,7 +564,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
node2.ID.String(),
nodeActiveBackupsSuffix,
)
counterCtx, counterCancel := context.WithTimeout(context.Background(), registry.timeout)
counterCtx, counterCancel := context.WithTimeout(t.Context(), registry.timeout)
defer counterCancel()
counterResult := registry.client.Do(
counterCtx,
@@ -575,7 +573,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
assert.Error(t, counterResult.Error())
activeInfoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node1.ID.String(), nodeInfoKeySuffix)
activeCtx, activeCancel := context.WithTimeout(context.Background(), registry.timeout)
activeCtx, activeCancel := context.WithTimeout(t.Context(), registry.timeout)
defer activeCancel()
activeResult := registry.client.Do(
activeCtx,
@@ -601,8 +599,6 @@ func createTestRegistry() *BackupNodesRegistry {
timeout: cache_utils.DefaultCacheTimeout,
pubsubBackups: cache_utils.NewPubSubManager(),
pubsubCompletions: cache_utils.NewPubSubManager(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
@@ -732,7 +728,7 @@ func Test_SubscribeNodeForBackupsAssignment_HandlesInvalidJson(t *testing.T) {
time.Sleep(100 * time.Millisecond)
ctx := context.Background()
ctx := t.Context()
err = registry.pubsubBackups.Publish(ctx, "backup:submit", "invalid json")
assert.NoError(t, err)
@@ -978,7 +974,7 @@ func Test_SubscribeForBackupsCompletions_HandlesInvalidJson(t *testing.T) {
time.Sleep(100 * time.Millisecond)
ctx := context.Background()
ctx := t.Context()
err = registry.pubsubCompletions.Publish(ctx, "backup:completion", "invalid json")
assert.NoError(t, err)
@@ -1093,7 +1089,7 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
receivedAll2 := []uuid.UUID{}
receivedAll3 := []uuid.UUID{}
for i := 0; i < 3; i++ {
for range 3 {
select {
case received := <-receivedBackups1:
receivedAll1 = append(receivedAll1, received)
@@ -1102,7 +1098,7 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
}
}
for i := 0; i < 3; i++ {
for range 3 {
select {
case received := <-receivedBackups2:
receivedAll2 = append(receivedAll2, received)
@@ -1111,7 +1107,7 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
}
}
for i := 0; i < 3; i++ {
for range 3 {
select {
case received := <-receivedBackups3:
receivedAll3 = append(receivedAll3, received)

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
@@ -29,6 +28,7 @@ type BackupsScheduler struct {
taskCancelManager *task_cancellation.TaskCancelManager
backupNodesRegistry *BackupNodesRegistry
databaseService *databases.DatabaseService
billingService BillingService
lastBackupTime time.Time
logger *slog.Logger
@@ -36,68 +36,61 @@ type BackupsScheduler struct {
backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
backuperNode *BackuperNode
runOnce sync.Once
hasRun atomic.Bool
hasRun atomic.Bool
}
func (s *BackupsScheduler) Run(ctx context.Context) {
wasAlreadyRun := s.hasRun.Load()
s.runOnce.Do(func() {
s.hasRun.Store(true)
s.lastBackupTime = time.Now().UTC()
if config.GetEnv().IsManyNodesMode {
// wait other nodes to start
time.Sleep(schedulerStartupDelay)
}
if err := s.failBackupsInProgress(); err != nil {
s.logger.Error("Failed to fail backups in progress", "error", err)
panic(err)
}
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
if err != nil {
s.logger.Error("Failed to subscribe to backup completions", "error", err)
panic(err)
}
defer func() {
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
}
}()
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(schedulerTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.checkDeadNodesAndFailBackups(); err != nil {
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
}
if err := s.runPendingBackups(); err != nil {
s.logger.Error("Failed to run pending backups", "error", err)
}
s.lastBackupTime = time.Now().UTC()
}
}
})
if wasAlreadyRun {
if s.hasRun.Swap(true) {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
s.lastBackupTime = time.Now().UTC()
if config.GetEnv().IsManyNodesMode {
// wait other nodes to start
time.Sleep(schedulerStartupDelay)
}
if err := s.failBackupsInProgress(); err != nil {
s.logger.Error("Failed to fail backups in progress", "error", err)
panic(err)
}
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
if err != nil {
s.logger.Error("Failed to subscribe to backup completions", "error", err)
panic(err)
}
defer func() {
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
}
}()
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(schedulerTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.checkDeadNodesAndFailBackups(); err != nil {
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
}
if err := s.runPendingBackups(); err != nil {
s.logger.Error("Failed to run pending backups", "error", err)
}
s.lastBackupTime = time.Now().UTC()
}
}
}
func (s *BackupsScheduler) IsSchedulerRunning() bool {
@@ -127,6 +120,34 @@ func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotif
return
}
if config.GetEnv().IsCloud {
subscription, subErr := s.billingService.GetSubscription(s.logger, database.ID)
if subErr != nil || !subscription.CanCreateNewBackups() {
failMessage := "subscription has expired, please renew"
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: *backupConfig.StorageID,
Status: backups_core.BackupStatusFailed,
FailMessage: &failMessage,
IsSkipRetry: true,
CreatedAt: time.Now().UTC(),
}
backup.GenerateFilename(database.Name)
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error(
"failed to save failed backup for expired subscription",
"database_id", database.ID,
"error", err,
)
}
return
}
}
// Check for existing in-progress backups
inProgressBackups, err := s.backupRepository.FindByDatabaseIdAndStatus(
database.ID,
@@ -346,6 +367,27 @@ func (s *BackupsScheduler) runPendingBackups() error {
continue
}
if config.GetEnv().IsCloud {
subscription, subErr := s.billingService.GetSubscription(s.logger, backupConfig.DatabaseID)
if subErr != nil {
s.logger.Warn(
"failed to get subscription, skipping backup",
"database_id", backupConfig.DatabaseID,
"error", subErr,
)
continue
}
if !subscription.CanCreateNewBackups() {
s.logger.Debug(
"subscription is not active, skipping scheduled backup",
"database_id", backupConfig.DatabaseID,
"subscription_status", subscription.Status,
)
continue
}
}
s.StartBackup(database, remainedBackupTryCount == 1)
continue
}

View File

@@ -10,6 +10,7 @@ import (
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
billing_models "databasus-backend/internal/features/billing/models"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
@@ -968,7 +969,7 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
scheduler := CreateTestScheduler()
scheduler := CreateTestScheduler(nil)
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
@@ -1065,7 +1066,7 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
scheduler := CreateTestScheduler()
scheduler := CreateTestScheduler(nil)
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
@@ -1332,7 +1333,7 @@ func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCa
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// Create scheduler
scheduler := CreateTestScheduler()
scheduler := CreateTestScheduler(nil)
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
@@ -1458,3 +1459,313 @@ func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCa
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_WhenCloudAndSubscriptionExpired_CreatesFailedBackup(t *testing.T) {
cache_utils.ClearAllCache()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{
Status: billing_models.StatusExpired,
},
}
scheduler := CreateTestScheduler(mockBilling)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
enableCloud(t)
scheduler.StartBackup(database, false)
time.Sleep(100 * time.Millisecond)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
newestBackup := backups[0]
assert.Equal(t, backups_core.BackupStatusFailed, newestBackup.Status)
assert.NotNil(t, newestBackup.FailMessage)
assert.Equal(t, "subscription has expired, please renew", *newestBackup.FailMessage)
assert.True(t, newestBackup.IsSkipRetry)
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_WhenCloudAndSubscriptionActive_ProceedsNormally(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{
Status: billing_models.StatusActive,
StorageGB: 10,
},
}
scheduler := CreateTestScheduler(mockBilling)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
enableCloud(t)
scheduler.StartBackup(database, false)
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
newestBackup := backups[0]
assert.Equal(t, backups_core.BackupStatusCompleted, newestBackup.Status)
time.Sleep(200 * time.Millisecond)
}
func Test_RunPendingBackups_WhenCloudAndSubscriptionExpired_SilentlySkips(t *testing.T) {
cache_utils.ClearAllCache()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{
Status: billing_models.StatusExpired,
},
}
scheduler := CreateTestScheduler(mockBilling)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
})
enableCloud(t)
scheduler.runPendingBackups()
time.Sleep(100 * time.Millisecond)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1, "No new backup should be created, scheduler silently skips expired subscriptions")
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_WhenNotCloudAndSubscriptionExpired_ProceedsNormally(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{
Status: billing_models.StatusExpired,
},
}
scheduler := CreateTestScheduler(mockBilling)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
scheduler.StartBackup(database, false)
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusCompleted, backups[0].Status,
"Billing check should not apply in non-cloud mode")
time.Sleep(200 * time.Millisecond)
}
func Test_RunPendingBackups_WhenNotCloudAndSubscriptionExpired_ProceedsNormally(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{
Status: billing_models.StatusExpired,
},
}
scheduler := CreateTestScheduler(mockBilling)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
})
scheduler.runPendingBackups()
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 2, "Billing check should not apply in non-cloud mode, new backup should be created")
time.Sleep(200 * time.Millisecond)
}

View File

@@ -3,7 +3,6 @@ package backuping
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
@@ -35,58 +34,70 @@ func CreateTestRouter() *gin.Engine {
return router
}
// CreateTestBackupCleaner builds a BackupCleaner for tests, wiring the shared
// test backup repository and the real singleton services, with the given
// billing service injected so tests can stub billing behavior.
func CreateTestBackupCleaner(billingService BillingService) *BackupCleaner {
	// NOTE(review): positional struct literal — the order here must match the
	// BackupCleaner field declaration order exactly; verify when fields change.
	return &BackupCleaner{
		backupRepository,
		storages.GetStorageService(),
		backups_config.GetBackupConfigService(),
		billingService,
		encryption.GetFieldEncryptor(),
		logger.GetLogger(),
		[]backups_core.BackupRemoveListener{}, // no remove listeners in tests
		atomic.Bool{},                         // zero-value run guard (unset)
	}
}
func CreateTestBackuperNode() *BackuperNode {
return &BackuperNode{
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
nodeID: uuid.New(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
taskCancelManager,
backupNodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
uuid.New(),
time.Time{},
atomic.Bool{},
}
}
func CreateTestBackuperNodeWithUseCase(useCase backups_core.CreateBackupUsecase) *BackuperNode {
return &BackuperNode{
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: useCase,
nodeID: uuid.New(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
taskCancelManager,
backupNodesRegistry,
logger.GetLogger(),
useCase,
uuid.New(),
time.Time{},
atomic.Bool{},
}
}
func CreateTestScheduler() *BackupsScheduler {
func CreateTestScheduler(billingService BillingService) *BackupsScheduler {
return &BackupsScheduler{
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
taskCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
lastBackupTime: time.Now().UTC(),
logger: logger.GetLogger(),
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
backuperNode: CreateTestBackuperNode(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
backupRepository,
backups_config.GetBackupConfigService(),
taskCancelManager,
backupNodesRegistry,
databases.GetDatabaseService(),
billingService,
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]BackupToNodeRelation),
CreateTestBackuperNode(),
atomic.Bool{},
}
}

View File

@@ -40,12 +40,15 @@ func (c *BackupController) RegisterPublicRoutes(router *gin.RouterGroup) {
// GetBackups
// @Summary Get backups for a database
// @Description Get paginated backups for the specified database
// @Description Get paginated backups for the specified database with optional filters
// @Tags backups
// @Produce json
// @Param database_id query string true "Database ID"
// @Param limit query int false "Number of items per page" default(10)
// @Param offset query int false "Offset for pagination" default(0)
// @Param status query []string false "Filter by backup status (can be repeated)" Enums(IN_PROGRESS, COMPLETED, FAILED, CANCELED)
// @Param beforeDate query string false "Filter backups created before this date (RFC3339)" format(date-time)
// @Param pgWalBackupType query string false "Filter by WAL backup type" Enums(PG_FULL_BACKUP, PG_WAL_SEGMENT)
// @Success 200 {object} backups_dto.GetBackupsResponse
// @Failure 400
// @Failure 401
@@ -70,7 +73,9 @@ func (c *BackupController) GetBackups(ctx *gin.Context) {
return
}
response, err := c.backupService.GetBackups(user, databaseID, request.Limit, request.Offset)
filters := c.buildBackupFilters(&request)
response, err := c.backupService.GetBackups(user, databaseID, request.Limit, request.Offset, filters)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -359,3 +364,35 @@ func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uu
}
}
}
// buildBackupFilters converts the optional query parameters of a
// GetBackupsRequest into a BackupFilters value. It returns nil when the
// request carries no filter at all, so callers can pass the result straight
// through to the service layer unchanged.
func (c *BackupController) buildBackupFilters(
	request *backups_dto.GetBackupsRequest,
) *backups_core.BackupFilters {
	noFilters := len(request.Statuses) == 0 &&
		request.BeforeDate == nil &&
		request.PgWalBackupType == nil
	if noFilters {
		return nil
	}
	result := &backups_core.BackupFilters{BeforeDate: request.BeforeDate}
	if len(request.Statuses) > 0 {
		// Convert the raw query strings into typed backup statuses.
		converted := make([]backups_core.BackupStatus, len(request.Statuses))
		for i, raw := range request.Statuses {
			converted[i] = backups_core.BackupStatus(raw)
		}
		result.Statuses = converted
	}
	if request.PgWalBackupType != nil {
		typed := backups_core.PgWalBackupType(*request.PgWalBackupType)
		result.PgWalBackupType = &typed
	}
	return result
}

View File

@@ -140,6 +140,225 @@ func Test_GetBackups_PermissionsEnforced(t *testing.T) {
}
}
// Test_GetBackups_WithStatusFilter_ReturnsFilteredBackups verifies the
// GET /backups endpoint honors the repeatable "status" query parameter:
// a single status returns only matching backups, and multiple statuses
// act as an OR filter over the listed values.
func Test_GetBackups_WithStatusFilter_ReturnsFilteredBackups(t *testing.T) {
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
	storage := createTestStorage(workspace.ID)
	defer func() {
		databases.RemoveTestDatabase(database)
		// presumably gives async backup-file cleanup time to finish before
		// the storage is removed — TODO confirm why 50ms is enough
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	now := time.Now().UTC()
	// Seed one backup per status so each filter case has a distinct match.
	CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
		Status:    backups_core.BackupStatusCompleted,
		CreatedAt: now.Add(-3 * time.Hour),
	})
	CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
		Status:    backups_core.BackupStatusFailed,
		CreatedAt: now.Add(-2 * time.Hour),
	})
	CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
		Status:    backups_core.BackupStatusCanceled,
		CreatedAt: now.Add(-1 * time.Hour),
	})
	// Single status filter
	var singleResponse backups_dto.GetBackupsResponse
	test_utils.MakeGetRequestAndUnmarshal(
		t,
		router,
		fmt.Sprintf("/api/v1/backups?database_id=%s&status=COMPLETED", database.ID.String()),
		"Bearer "+owner.Token,
		http.StatusOK,
		&singleResponse,
	)
	assert.Equal(t, int64(1), singleResponse.Total)
	assert.Len(t, singleResponse.Backups, 1)
	assert.Equal(t, backups_core.BackupStatusCompleted, singleResponse.Backups[0].Status)
	// Multiple status filter
	var multiResponse backups_dto.GetBackupsResponse
	test_utils.MakeGetRequestAndUnmarshal(
		t,
		router,
		fmt.Sprintf(
			"/api/v1/backups?database_id=%s&status=COMPLETED&status=FAILED",
			database.ID.String(),
		),
		"Bearer "+owner.Token,
		http.StatusOK,
		&multiResponse,
	)
	assert.Equal(t, int64(2), multiResponse.Total)
	assert.Len(t, multiResponse.Backups, 2)
	// Every returned backup must be one of the two requested statuses.
	for _, backup := range multiResponse.Backups {
		assert.True(
			t,
			backup.Status == backups_core.BackupStatusCompleted ||
				backup.Status == backups_core.BackupStatusFailed,
			"expected COMPLETED or FAILED, got %s", backup.Status,
		)
	}
}
// Test_GetBackups_WithBeforeDateFilter_ReturnsFilteredBackups verifies the
// "beforeDate" query parameter (RFC3339): only backups created strictly
// before the cutoff are returned; newer ones are filtered out.
func Test_GetBackups_WithBeforeDateFilter_ReturnsFilteredBackups(t *testing.T) {
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
	storage := createTestStorage(workspace.ID)
	defer func() {
		databases.RemoveTestDatabase(database)
		// presumably gives async cleanup time to finish — TODO confirm
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	now := time.Now().UTC()
	cutoff := now.Add(-1 * time.Hour)
	// Older than the cutoff — should be the only match.
	olderBackup := CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
		Status:    backups_core.BackupStatusCompleted,
		CreatedAt: now.Add(-3 * time.Hour),
	})
	// Newer than the cutoff — should be excluded by the filter.
	CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
		Status:    backups_core.BackupStatusCompleted,
		CreatedAt: now,
	})
	var response backups_dto.GetBackupsResponse
	test_utils.MakeGetRequestAndUnmarshal(
		t,
		router,
		fmt.Sprintf(
			"/api/v1/backups?database_id=%s&beforeDate=%s",
			database.ID.String(),
			cutoff.Format(time.RFC3339),
		),
		"Bearer "+owner.Token,
		http.StatusOK,
		&response,
	)
	assert.Equal(t, int64(1), response.Total)
	assert.Len(t, response.Backups, 1)
	assert.Equal(t, olderBackup.ID, response.Backups[0].ID)
}
// Test_GetBackups_WithPgWalBackupTypeFilter_ReturnsFilteredBackups verifies
// the "pgWalBackupType" query parameter: requesting PG_FULL_BACKUP returns
// only the full backup, not the WAL segment.
func Test_GetBackups_WithPgWalBackupTypeFilter_ReturnsFilteredBackups(t *testing.T) {
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
	storage := createTestStorage(workspace.ID)
	defer func() {
		databases.RemoveTestDatabase(database)
		// presumably gives async cleanup time to finish — TODO confirm
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	now := time.Now().UTC()
	fullBackupType := backups_core.PgWalBackupTypeFullBackup
	walSegmentType := backups_core.PgWalBackupTypeWalSegment
	// One backup of each WAL type; only the full backup should match.
	fullBackup := CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
		Status:          backups_core.BackupStatusCompleted,
		CreatedAt:       now.Add(-2 * time.Hour),
		PgWalBackupType: &fullBackupType,
	})
	CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
		Status:          backups_core.BackupStatusCompleted,
		CreatedAt:       now.Add(-1 * time.Hour),
		PgWalBackupType: &walSegmentType,
	})
	var response backups_dto.GetBackupsResponse
	test_utils.MakeGetRequestAndUnmarshal(
		t,
		router,
		fmt.Sprintf(
			"/api/v1/backups?database_id=%s&pgWalBackupType=PG_FULL_BACKUP",
			database.ID.String(),
		),
		"Bearer "+owner.Token,
		http.StatusOK,
		&response,
	)
	assert.Equal(t, int64(1), response.Total)
	assert.Len(t, response.Backups, 1)
	assert.Equal(t, fullBackup.ID, response.Backups[0].ID)
}
// Test_GetBackups_WithCombinedFilters_ReturnsFilteredBackups verifies that
// status and beforeDate filters are combined with AND semantics: only the
// backup matching BOTH criteria is returned.
func Test_GetBackups_WithCombinedFilters_ReturnsFilteredBackups(t *testing.T) {
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
	storage := createTestStorage(workspace.ID)
	defer func() {
		databases.RemoveTestDatabase(database)
		// presumably gives async cleanup time to finish — TODO confirm
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	now := time.Now().UTC()
	cutoff := now.Add(-1 * time.Hour)
	// Old completed — should match
	oldCompleted := CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
		Status:    backups_core.BackupStatusCompleted,
		CreatedAt: now.Add(-3 * time.Hour),
	})
	// Old failed — should NOT match (wrong status)
	CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
		Status:    backups_core.BackupStatusFailed,
		CreatedAt: now.Add(-2 * time.Hour),
	})
	// New completed — should NOT match (too recent)
	CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
		Status:    backups_core.BackupStatusCompleted,
		CreatedAt: now,
	})
	var response backups_dto.GetBackupsResponse
	test_utils.MakeGetRequestAndUnmarshal(
		t,
		router,
		fmt.Sprintf(
			"/api/v1/backups?database_id=%s&status=COMPLETED&beforeDate=%s",
			database.ID.String(),
			cutoff.Format(time.RFC3339),
		),
		"Bearer "+owner.Token,
		http.StatusOK,
		&response,
	)
	assert.Equal(t, int64(1), response.Total)
	assert.Len(t, response.Backups, 1)
	assert.Equal(t, oldCompleted.ID, response.Backups[0].ID)
}
func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
@@ -376,7 +595,7 @@ func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
ownerUser, err := userService.GetUserFromToken(owner.Token)
assert.NoError(t, err)
response, err := backups_services.GetBackupService().GetBackups(ownerUser, database.ID, 10, 0)
response, err := backups_services.GetBackupService().GetBackups(ownerUser, database.ID, 10, 0, nil)
assert.NoError(t, err)
assert.Equal(t, 0, len(response.Backups))
}
@@ -1263,7 +1482,7 @@ func Test_MakeBackup_VerifyBackupAndMetadataFilesExistInStorage(t *testing.T) {
backuperCancel := backuping.StartBackuperNodeForTest(t, backuperNode)
defer backuping.StopBackuperNodeForTest(t, backuperCancel, backuperNode)
scheduler := backuping.CreateTestScheduler()
scheduler := backuping.CreateTestScheduler(nil)
schedulerCancel := backuping.StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
@@ -1838,7 +2057,7 @@ func Test_DeleteBackup_RemovesBackupAndMetadataFilesFromDisk(t *testing.T) {
backuperCancel := backuping.StartBackuperNodeForTest(t, backuperNode)
defer backuping.StopBackuperNodeForTest(t, backuperCancel, backuperNode)
scheduler := backuping.CreateTestScheduler()
scheduler := backuping.CreateTestScheduler(nil)
schedulerCancel := backuping.StartSchedulerForTest(t, scheduler)
defer schedulerCancel()

View File

@@ -95,3 +95,33 @@ func CreateTestBackup(databaseID, storageID uuid.UUID) *backups_core.Backup {
return backup
}
// TestBackupOptions describes the customizable fields for a backup created
// via CreateTestBackupWithOptions.
type TestBackupOptions struct {
	// Status is the backup status to persist (e.g. COMPLETED, FAILED).
	Status backups_core.BackupStatus
	// CreatedAt is the creation timestamp stored on the backup row.
	CreatedAt time.Time
	// PgWalBackupType optionally marks the backup as a PG full backup or a
	// WAL segment; nil leaves the column unset.
	PgWalBackupType *backups_core.PgWalBackupType
}
// CreateTestBackupWithOptions persists a backup row with caller-controlled
// status, creation time, and WAL backup type, returning the saved entity.
// A save failure panics, which fails the calling test immediately.
func CreateTestBackupWithOptions(
	databaseID, storageID uuid.UUID,
	opts TestBackupOptions,
) *backups_core.Backup {
	entity := &backups_core.Backup{
		ID:               uuid.New(),
		DatabaseID:       databaseID,
		StorageID:        storageID,
		Status:           opts.Status,
		BackupSizeMb:     10.5, // fixed dummy size for test rows
		BackupDurationMs: 1000, // fixed dummy duration for test rows
		PgWalBackupType:  opts.PgWalBackupType,
		CreatedAt:        opts.CreatedAt,
	}
	if err := (&backups_core.BackupRepository{}).Save(entity); err != nil {
		panic(err)
	}
	return entity
}

View File

@@ -0,0 +1,9 @@
package backups_core
import "time"
// BackupFilters holds optional criteria for narrowing backup queries.
// A nil *BackupFilters, or any zero-value field, means "no filtering"
// for that criterion.
type BackupFilters struct {
	// Statuses restricts results to backups whose status is in the list.
	Statuses []BackupStatus
	// BeforeDate keeps only backups created strictly before this instant.
	BeforeDate *time.Time
	// PgWalBackupType keeps only backups of the given WAL backup type.
	PgWalBackupType *PgWalBackupType
}

View File

@@ -422,3 +422,67 @@ func (r *BackupRepository) FindLastWalSegmentAfter(
return &backup, nil
}
// FindByDatabaseIDWithFiltersAndPagination returns one page of backups for
// the given database, newest first. A nil filters value applies no extra
// criteria; otherwise each filter condition is ANDed onto the query.
func (r *BackupRepository) FindByDatabaseIDWithFiltersAndPagination(
	databaseID uuid.UUID,
	filters *BackupFilters,
	limit, offset int,
) ([]*Backup, error) {
	tx := storage.
		GetDb().
		Where("database_id = ?", databaseID)
	if filters != nil {
		tx = filters.applyToQuery(tx)
	}
	tx = tx.
		Order("created_at DESC").
		Limit(limit).
		Offset(offset)
	var result []*Backup
	err := tx.Find(&result).Error
	if err != nil {
		return nil, err
	}
	return result, nil
}
// CountByDatabaseIDWithFilters returns how many backups the given database
// has after applying the optional filters (nil means count everything).
func (r *BackupRepository) CountByDatabaseIDWithFilters(
	databaseID uuid.UUID,
	filters *BackupFilters,
) (int64, error) {
	tx := storage.
		GetDb().
		Model(&Backup{}).
		Where("database_id = ?", databaseID)
	if filters != nil {
		tx = filters.applyToQuery(tx)
	}
	var total int64
	err := tx.Count(&total).Error
	if err != nil {
		return 0, err
	}
	return total, nil
}
// applyToQuery appends each set filter criterion to the GORM query as an
// AND condition and returns the augmented query. Unset (zero-value) fields
// add no condition.
func (f *BackupFilters) applyToQuery(query *gorm.DB) *gorm.DB {
	if len(f.Statuses) > 0 {
		query = query.Where("status IN ?", f.Statuses)
	}
	if f.BeforeDate != nil {
		// Strict inequality: backups created exactly at BeforeDate are excluded.
		query = query.Where("created_at < ?", *f.BeforeDate)
	}
	if f.PgWalBackupType != nil {
		query = query.Where("pg_wal_backup_type = ?", *f.PgWalBackupType)
	}
	return query
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
)
@@ -13,38 +12,31 @@ type DownloadTokenBackgroundService struct {
downloadTokenService *DownloadTokenService
logger *slog.Logger
runOnce sync.Once
hasRun atomic.Bool
hasRun atomic.Bool
}
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
wasAlreadyRun := s.hasRun.Load()
s.runOnce.Do(func() {
s.hasRun.Store(true)
s.logger.Info("Starting download token cleanup background service")
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
s.logger.Error("Failed to clean expired download tokens", "error", err)
}
}
}
})
if wasAlreadyRun {
if s.hasRun.Swap(true) {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
s.logger.Info("Starting download token cleanup background service")
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
s.logger.Error("Failed to clean expired download tokens", "error", err)
}
}
}
}

View File

@@ -1,9 +1,6 @@
package backups_download
import (
"sync"
"sync/atomic"
"databasus-backend/internal/config"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
@@ -37,8 +34,6 @@ func init() {
downloadTokenBackgroundService = &DownloadTokenBackgroundService{
downloadTokenService: downloadTokenService,
logger: logger.GetLogger(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}

View File

@@ -11,9 +11,12 @@ import (
)
type GetBackupsRequest struct {
DatabaseID string `form:"database_id" binding:"required"`
Limit int `form:"limit"`
Offset int `form:"offset"`
DatabaseID string `form:"database_id" binding:"required"`
Limit int `form:"limit"`
Offset int `form:"offset"`
Statuses []string `form:"status"`
BeforeDate *time.Time `form:"beforeDate"`
PgWalBackupType *string `form:"pgWalBackupType"`
}
type GetBackupsResponse struct {

View File

@@ -2,7 +2,6 @@ package backups_services
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/backuping"
@@ -59,26 +58,11 @@ func GetWalService() *PostgreWalBackupService {
return walService
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
var SetupDependencies = sync.OnceFunc(func() {
backups_config.
GetBackupConfigService().
SetDatabaseStorageChangeListener(backupService)
func SetupDependencies() {
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
backups_config.
GetBackupConfigService().
SetDatabaseStorageChangeListener(backupService)
databases.GetDatabaseService().AddDbRemoveListener(backupService)
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}
databases.GetDatabaseService().AddDbRemoveListener(backupService)
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
})

View File

@@ -109,6 +109,7 @@ func (s *BackupService) GetBackups(
user *users_models.User,
databaseID uuid.UUID,
limit, offset int,
filters *backups_core.BackupFilters,
) (*backups_dto.GetBackupsResponse, error) {
database, err := s.databaseService.GetDatabaseByID(databaseID)
if err != nil {
@@ -134,12 +135,14 @@ func (s *BackupService) GetBackups(
offset = 0
}
backups, err := s.backupRepository.FindByDatabaseIDWithPagination(databaseID, limit, offset)
backups, err := s.backupRepository.FindByDatabaseIDWithFiltersAndPagination(
databaseID, filters, limit, offset,
)
if err != nil {
return nil, err
}
total, err := s.backupRepository.CountByDatabaseID(databaseID)
total, err := s.backupRepository.CountByDatabaseIDWithFilters(databaseID, filters)
if err != nil {
return nil, err
}

View File

@@ -109,6 +109,7 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
"--routines",
"--quick",
"--skip-extended-insert",
"--skip-add-locks",
"--verbose",
}
@@ -129,6 +130,8 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
if mdb.IsHttps {
args = append(args, "--ssl")
args = append(args, "--skip-ssl-verify-server-cert")
} else {
args = append(args, "--skip-ssl")
}
if mdb.Database != nil && *mdb.Database != "" {
@@ -278,15 +281,9 @@ func (uc *CreateMariadbBackupUsecase) createTempMyCnfFile(
mdbConfig *mariadbtypes.MariadbDatabase,
password string,
) (string, error) {
tempFolder := config.GetEnv().TempFolder
if err := os.MkdirAll(tempFolder, 0o700); err != nil {
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
}
if err := os.Chmod(tempFolder, 0o700); err != nil {
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
}
tempDir, err := os.MkdirTemp(tempFolder, "mycnf_"+uuid.New().String())
// Credential files use OS temp dir (/tmp) because some filesystems
// (e.g. ZFS on TrueNAS) ignore chmod, causing "group or world access" errors.
tempDir, err := os.MkdirTemp(os.TempDir(), "mycnf_"+uuid.New().String())
if err != nil {
return "", fmt.Errorf("failed to create temp directory: %w", err)
}

View File

@@ -108,6 +108,7 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
"--set-gtid-purged=OFF",
"--quick",
"--skip-extended-insert",
"--skip-add-locks",
"--verbose",
}
@@ -126,6 +127,8 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
if my.IsHttps {
args = append(args, "--ssl-mode=REQUIRED")
} else {
args = append(args, "--ssl-mode=DISABLED")
}
if my.Database != nil && *my.Database != "" {
@@ -297,15 +300,9 @@ func (uc *CreateMysqlBackupUsecase) createTempMyCnfFile(
myConfig *mysqltypes.MysqlDatabase,
password string,
) (string, error) {
tempFolder := config.GetEnv().TempFolder
if err := os.MkdirAll(tempFolder, 0o700); err != nil {
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
}
if err := os.Chmod(tempFolder, 0o700); err != nil {
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
}
tempDir, err := os.MkdirTemp(tempFolder, "mycnf_"+uuid.New().String())
// Credential files use OS temp dir (/tmp) because some filesystems
// (e.g. ZFS on TrueNAS) ignore chmod, causing "group or world access" errors.
tempDir, err := os.MkdirTemp(os.TempDir(), "mycnf_"+uuid.New().String())
if err != nil {
return "", fmt.Errorf("failed to create temp directory: %w", err)
}
@@ -326,6 +323,8 @@ port=%d
if myConfig.IsHttps {
content += "ssl-mode=REQUIRED\n"
} else {
content += "ssl-mode=DISABLED\n"
}
err = os.WriteFile(myCnfFile, []byte(content), 0o600)

View File

@@ -747,15 +747,9 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
escapedPassword,
)
tempFolder := config.GetEnv().TempFolder
if err := os.MkdirAll(tempFolder, 0o700); err != nil {
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
}
if err := os.Chmod(tempFolder, 0o700); err != nil {
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
}
tempDir, err := os.MkdirTemp(tempFolder, "pgpass_"+uuid.New().String())
// Credential files use OS temp dir (/tmp) because some filesystems
// (e.g. ZFS on TrueNAS) ignore chmod, causing "group or world access" errors.
tempDir, err := os.MkdirTemp(os.TempDir(), "pgpass_"+uuid.New().String())
if err != nil {
return "", fmt.Errorf("failed to create temporary directory: %w", err)
}

View File

@@ -16,7 +16,6 @@ type BackupConfigController struct {
func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
router.POST("/backup-configs/save", c.SaveBackupConfig)
router.GET("/backup-configs/database/:id/plan", c.GetDatabasePlan)
router.GET("/backup-configs/database/:id", c.GetBackupConfigByDbID)
router.GET("/backup-configs/storage/:id/is-using", c.IsStorageUsing)
router.GET("/backup-configs/storage/:id/databases-count", c.CountDatabasesForStorage)
@@ -93,39 +92,6 @@ func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
ctx.JSON(http.StatusOK, backupConfig)
}
// GetDatabasePlan
// @Summary Get database plan by database ID
// @Description Get the plan limits for a specific database (max backup size, max total size, max storage period)
// @Tags backup-configs
// @Produce json
// @Param id path string true "Database ID"
// @Success 200 {object} plans.DatabasePlan
// @Failure 400 {object} map[string]string "Invalid database ID"
// @Failure 401 {object} map[string]string "User not authenticated"
// @Failure 404 {object} map[string]string "Database not found or access denied"
// @Router /backup-configs/database/{id}/plan [get]
func (c *BackupConfigController) GetDatabasePlan(ctx *gin.Context) {
	// Reject unauthenticated requests before touching the path parameter.
	currentUser, authenticated := users_middleware.GetUserFromContext(ctx)
	if !authenticated {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}
	databaseID, parseErr := uuid.Parse(ctx.Param("id"))
	if parseErr != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
		return
	}
	// Service-level lookup also enforces access; any failure maps to 404.
	plan, lookupErr := c.backupConfigService.GetDatabasePlan(currentUser, databaseID)
	if lookupErr != nil {
		ctx.JSON(http.StatusNotFound, gin.H{"error": "database plan not found"})
		return
	}
	ctx.JSON(http.StatusOK, plan)
}
// IsStorageUsing
// @Summary Check if storage is being used
// @Description Check if a storage is currently being used by any backup configuration

View File

@@ -17,14 +17,12 @@ import (
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
local_storage "databasus-backend/internal/features/storages/models/local"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/storage"
"databasus-backend/internal/util/period"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
@@ -326,218 +324,13 @@ func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T)
&response,
)
var plan plans.DatabasePlan
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
"Bearer "+owner.Token,
http.StatusOK,
&plan,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.False(t, response.IsBackupsEnabled)
assert.Equal(t, plan.MaxStoragePeriod, response.RetentionTimePeriod)
assert.Equal(t, plan.MaxBackupSizeMB, response.MaxBackupSizeMB)
assert.Equal(t, plan.MaxBackupsTotalSizeMB, response.MaxBackupsTotalSizeMB)
assert.True(t, response.IsRetryIfFailed)
assert.Equal(t, 3, response.MaxFailedTriesCount)
assert.NotNil(t, response.BackupInterval)
}
// Test_GetDatabasePlan_ForNewDatabase_PlanAlwaysReturned verifies that the
// plan endpoint returns a populated DatabasePlan for a freshly created
// database (i.e. the plan is available even before any backup config exists).
func Test_GetDatabasePlan_ForNewDatabase_PlanAlwaysReturned(t *testing.T) {
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
	defer func() {
		databases.RemoveTestDatabase(database)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	var response plans.DatabasePlan
	test_utils.MakeGetRequestAndUnmarshal(
		t,
		router,
		"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
		"Bearer "+owner.Token,
		http.StatusOK,
		&response,
	)
	// Plan must belong to the new database and carry non-empty limits.
	assert.Equal(t, database.ID, response.DatabaseID)
	assert.NotNil(t, response.MaxBackupSizeMB)
	assert.NotNil(t, response.MaxBackupsTotalSizeMB)
	assert.NotEmpty(t, response.MaxStoragePeriod)
}
// Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced tightens
// a database's plan limits directly in the DB, then verifies the save
// endpoint rejects configs exceeding each limit (backup size, total size,
// storage period) with a specific error, and accepts a config within limits.
func Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced(t *testing.T) {
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
	defer func() {
		databases.RemoveTestDatabase(database)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	// Get plan via API (triggers auto-creation)
	var plan plans.DatabasePlan
	test_utils.MakeGetRequestAndUnmarshal(
		t,
		router,
		"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
		"Bearer "+owner.Token,
		http.StatusOK,
		&plan,
	)
	assert.Equal(t, database.ID, plan.DatabaseID)
	// Adjust plan limits directly in database to fixed restrictive values
	err := storage.GetDb().Model(&plans.DatabasePlan{}).
		Where("database_id = ?", database.ID).
		Updates(map[string]any{
			"max_backup_size_mb":        100,
			"max_backups_total_size_mb": 1000,
			"max_storage_period":        period.PeriodMonth,
		}).Error
	assert.NoError(t, err)
	// Test 1: Try to save backup config with exceeded backup size limit
	timeOfDay := "04:00"
	backupConfigExceededSize := BackupConfig{
		DatabaseID:          database.ID,
		IsBackupsEnabled:    true,
		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
		RetentionTimePeriod: period.PeriodWeek,
		BackupInterval: &intervals.Interval{
			Interval:  intervals.IntervalDaily,
			TimeOfDay: &timeOfDay,
		},
		SendNotificationsOn: []BackupNotificationType{
			NotificationBackupFailed,
		},
		IsRetryIfFailed:       true,
		MaxFailedTriesCount:   3,
		Encryption:            BackupEncryptionNone,
		MaxBackupSizeMB:       200, // Exceeds limit of 100
		MaxBackupsTotalSizeMB: 800,
	}
	respExceededSize := test_utils.MakePostRequest(
		t,
		router,
		"/api/v1/backup-configs/save",
		"Bearer "+owner.Token,
		backupConfigExceededSize,
		http.StatusBadRequest,
	)
	assert.Contains(t, string(respExceededSize.Body), "max backup size exceeds plan limit")
	// Test 2: Try to save backup config with exceeded total size limit
	backupConfigExceededTotal := BackupConfig{
		DatabaseID:          database.ID,
		IsBackupsEnabled:    true,
		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
		RetentionTimePeriod: period.PeriodWeek,
		BackupInterval: &intervals.Interval{
			Interval:  intervals.IntervalDaily,
			TimeOfDay: &timeOfDay,
		},
		SendNotificationsOn: []BackupNotificationType{
			NotificationBackupFailed,
		},
		IsRetryIfFailed:       true,
		MaxFailedTriesCount:   3,
		Encryption:            BackupEncryptionNone,
		MaxBackupSizeMB:       50,
		MaxBackupsTotalSizeMB: 2000, // Exceeds limit of 1000
	}
	respExceededTotal := test_utils.MakePostRequest(
		t,
		router,
		"/api/v1/backup-configs/save",
		"Bearer "+owner.Token,
		backupConfigExceededTotal,
		http.StatusBadRequest,
	)
	assert.Contains(t, string(respExceededTotal.Body), "max total backups size exceeds plan limit")
	// Test 3: Try to save backup config with exceeded storage period limit
	backupConfigExceededPeriod := BackupConfig{
		DatabaseID:          database.ID,
		IsBackupsEnabled:    true,
		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
		RetentionTimePeriod: period.PeriodYear, // Exceeds limit of Month
		BackupInterval: &intervals.Interval{
			Interval:  intervals.IntervalDaily,
			TimeOfDay: &timeOfDay,
		},
		SendNotificationsOn: []BackupNotificationType{
			NotificationBackupFailed,
		},
		IsRetryIfFailed:       true,
		MaxFailedTriesCount:   3,
		Encryption:            BackupEncryptionNone,
		MaxBackupSizeMB:       80,
		MaxBackupsTotalSizeMB: 800,
	}
	respExceededPeriod := test_utils.MakePostRequest(
		t,
		router,
		"/api/v1/backup-configs/save",
		"Bearer "+owner.Token,
		backupConfigExceededPeriod,
		http.StatusBadRequest,
	)
	assert.Contains(t, string(respExceededPeriod.Body), "storage period exceeds plan limit")
	// Test 4: Save backup config within all limits - should succeed
	backupConfigValid := BackupConfig{
		DatabaseID:          database.ID,
		IsBackupsEnabled:    true,
		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
		RetentionTimePeriod: period.PeriodWeek, // Within Month limit
		BackupInterval: &intervals.Interval{
			Interval:  intervals.IntervalDaily,
			TimeOfDay: &timeOfDay,
		},
		SendNotificationsOn: []BackupNotificationType{
			NotificationBackupFailed,
		},
		IsRetryIfFailed:       true,
		MaxFailedTriesCount:   3,
		Encryption:            BackupEncryptionNone,
		MaxBackupSizeMB:       80,  // Within 100 limit
		MaxBackupsTotalSizeMB: 800, // Within 1000 limit
	}
	var responseValid BackupConfig
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/backup-configs/save",
		"Bearer "+owner.Token,
		backupConfigValid,
		http.StatusOK,
		&responseValid,
	)
	// Saved config must echo back the requested, in-limit values.
	assert.Equal(t, database.ID, responseValid.DatabaseID)
	assert.Equal(t, int64(80), responseValid.MaxBackupSizeMB)
	assert.Equal(t, int64(800), responseValid.MaxBackupsTotalSizeMB)
	assert.Equal(t, period.PeriodWeek, responseValid.RetentionTimePeriod)
}
func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string

View File

@@ -2,14 +2,11 @@ package backups_config
import (
"sync"
"sync/atomic"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/logger"
)
var (
@@ -20,7 +17,6 @@ var (
storages.GetStorageService(),
notifiers.GetNotifierService(),
workspaces_services.GetWorkspaceService(),
plans.GetDatabasePlanService(),
nil,
}
)
@@ -37,21 +33,6 @@ func GetBackupConfigService() *BackupConfigService {
return backupConfigService
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}
var SetupDependencies = sync.OnceFunc(func() {
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
})

View File

@@ -7,5 +7,5 @@ type TransferDatabaseRequest struct {
TargetStorageID *uuid.UUID `json:"targetStorageId,omitempty"`
IsTransferWithStorage bool `json:"isTransferWithStorage,omitempty"`
IsTransferWithNotifiers bool `json:"isTransferWithNotifiers,omitempty"`
TargetNotifierIDs []uuid.UUID `json:"targetNotifierIds,omitempty"`
TargetNotifierIDs []uuid.UUID `json:"targetNotifierIds,omitzero"`
}

View File

@@ -9,7 +9,6 @@ import (
"databasus-backend/internal/config"
"databasus-backend/internal/features/intervals"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/util/period"
)
@@ -29,8 +28,8 @@ type BackupConfig struct {
RetentionGfsMonths int `json:"retentionGfsMonths" gorm:"column:retention_gfs_months;type:int;not null;default:0"`
RetentionGfsYears int `json:"retentionGfsYears" gorm:"column:retention_gfs_years;type:int;not null;default:0"`
BackupIntervalID uuid.UUID `json:"backupIntervalId" gorm:"column:backup_interval_id;type:uuid;not null"`
BackupInterval *intervals.Interval `json:"backupInterval,omitempty" gorm:"foreignKey:BackupIntervalID"`
BackupIntervalID uuid.UUID `json:"backupIntervalId" gorm:"column:backup_interval_id;type:uuid;not null"`
BackupInterval *intervals.Interval `json:"backupInterval,omitzero" gorm:"foreignKey:BackupIntervalID"`
Storage *storages.Storage `json:"storage" gorm:"foreignKey:StorageID"`
StorageID *uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;"`
@@ -42,11 +41,6 @@ type BackupConfig struct {
MaxFailedTriesCount int `json:"maxFailedTriesCount" gorm:"column:max_failed_tries_count;type:int;not null"`
Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
// MaxBackupSizeMB limits individual backup size. 0 = unlimited.
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
// MaxBackupsTotalSizeMB limits total size of all backups. 0 = unlimited.
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
}
func (h *BackupConfig) TableName() string {
@@ -86,12 +80,12 @@ func (b *BackupConfig) AfterFind(tx *gorm.DB) error {
return nil
}
func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
func (b *BackupConfig) Validate() error {
if b.BackupIntervalID == uuid.Nil && b.BackupInterval == nil {
return errors.New("backup interval is required")
}
if err := b.validateRetentionPolicy(plan); err != nil {
if err := b.validateRetentionPolicy(); err != nil {
return err
}
@@ -110,67 +104,38 @@ func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
}
}
if b.MaxBackupSizeMB < 0 {
return errors.New("max backup size must be non-negative")
}
if b.MaxBackupsTotalSizeMB < 0 {
return errors.New("max backups total size must be non-negative")
}
if plan.MaxBackupSizeMB > 0 {
if b.MaxBackupSizeMB == 0 || b.MaxBackupSizeMB > plan.MaxBackupSizeMB {
return errors.New("max backup size exceeds plan limit")
}
}
if plan.MaxBackupsTotalSizeMB > 0 {
if b.MaxBackupsTotalSizeMB == 0 ||
b.MaxBackupsTotalSizeMB > plan.MaxBackupsTotalSizeMB {
return errors.New("max total backups size exceeds plan limit")
}
}
return nil
}
func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
return &BackupConfig{
DatabaseID: newDatabaseID,
IsBackupsEnabled: b.IsBackupsEnabled,
RetentionPolicyType: b.RetentionPolicyType,
RetentionTimePeriod: b.RetentionTimePeriod,
RetentionCount: b.RetentionCount,
RetentionGfsHours: b.RetentionGfsHours,
RetentionGfsDays: b.RetentionGfsDays,
RetentionGfsWeeks: b.RetentionGfsWeeks,
RetentionGfsMonths: b.RetentionGfsMonths,
RetentionGfsYears: b.RetentionGfsYears,
BackupIntervalID: uuid.Nil,
BackupInterval: b.BackupInterval.Copy(),
StorageID: b.StorageID,
SendNotificationsOn: b.SendNotificationsOn,
IsRetryIfFailed: b.IsRetryIfFailed,
MaxFailedTriesCount: b.MaxFailedTriesCount,
Encryption: b.Encryption,
MaxBackupSizeMB: b.MaxBackupSizeMB,
MaxBackupsTotalSizeMB: b.MaxBackupsTotalSizeMB,
DatabaseID: newDatabaseID,
IsBackupsEnabled: b.IsBackupsEnabled,
RetentionPolicyType: b.RetentionPolicyType,
RetentionTimePeriod: b.RetentionTimePeriod,
RetentionCount: b.RetentionCount,
RetentionGfsHours: b.RetentionGfsHours,
RetentionGfsDays: b.RetentionGfsDays,
RetentionGfsWeeks: b.RetentionGfsWeeks,
RetentionGfsMonths: b.RetentionGfsMonths,
RetentionGfsYears: b.RetentionGfsYears,
BackupIntervalID: uuid.Nil,
BackupInterval: b.BackupInterval.Copy(),
StorageID: b.StorageID,
SendNotificationsOn: b.SendNotificationsOn,
IsRetryIfFailed: b.IsRetryIfFailed,
MaxFailedTriesCount: b.MaxFailedTriesCount,
Encryption: b.Encryption,
}
}
func (b *BackupConfig) validateRetentionPolicy(plan *plans.DatabasePlan) error {
func (b *BackupConfig) validateRetentionPolicy() error {
switch b.RetentionPolicyType {
case RetentionPolicyTypeTimePeriod, "":
if b.RetentionTimePeriod == "" {
return errors.New("retention time period is required")
}
if plan.MaxStoragePeriod != period.PeriodForever {
if b.RetentionTimePeriod.CompareTo(plan.MaxStoragePeriod) > 0 {
return errors.New("storage period exceeds plan limit")
}
}
case RetentionPolicyTypeCount:
if b.RetentionCount <= 0 {
return errors.New("retention count must be greater than 0")

View File

@@ -6,248 +6,34 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/intervals"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/util/period"
)
func Test_Validate_WhenRetentionTimePeriodIsWeekAndPlanAllowsMonth_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodWeek
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenRetentionTimePeriodIsYearAndPlanAllowsMonth_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodYear
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
err := config.Validate(plan)
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenRetentionTimePeriodIsForeverAndPlanAllowsForever_ValidationPasses(
t *testing.T,
) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodForever
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodForever
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenRetentionTimePeriodIsForeverAndPlanAllowsYear_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodForever
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodYear
err := config.Validate(plan)
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenRetentionTimePeriodEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodMonth
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenBackupSize100MBAndPlanAllows500MB_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 100
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 500
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenBackupSize500MBAndPlanAllows100MB_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 500
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 100
err := config.Validate(plan)
assert.EqualError(t, err, "max backup size exceeds plan limit")
}
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 0
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanHas500MBLimit_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 500
err := config.Validate(plan)
assert.EqualError(t, err, "max backup size exceeds plan limit")
}
func Test_Validate_WhenBackupSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 500
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 500
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenTotalSize1GBAndPlanAllows5GB_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 1000
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 5000
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenTotalSize5GBAndPlanAllows1GB_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 5000
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 1000
err := config.Validate(plan)
assert.EqualError(t, err, "max total backups size exceeds plan limit")
}
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 0
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanHas1GBLimit_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 1000
err := config.Validate(plan)
assert.EqualError(t, err, "max total backups size exceeds plan limit")
}
func Test_Validate_WhenTotalSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 5000
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 5000
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenAllLimitsAreUnlimitedInPlan_AnyConfigurationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodForever
config.MaxBackupSizeMB = 0
config.MaxBackupsTotalSizeMB = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenMultipleLimitsExceeded_ValidationFailsWithFirstError(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodYear
config.MaxBackupSizeMB = 500
config.MaxBackupsTotalSizeMB = 5000
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
plan.MaxBackupSizeMB = 100
plan.MaxBackupsTotalSizeMB = 1000
err := config.Validate(plan)
assert.Error(t, err)
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenConfigHasInvalidIntervalButPlanIsValid_ValidationFailsOnInterval(
t *testing.T,
) {
func Test_Validate_WhenIntervalIsMissing_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.BackupIntervalID = uuid.Nil
config.BackupInterval = nil
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "backup interval is required")
}
func Test_Validate_WhenIntervalIsMissing_ValidationFailsRegardlessOfPlan(t *testing.T) {
config := createValidBackupConfig()
config.BackupIntervalID = uuid.Nil
config.BackupInterval = nil
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "backup interval is required")
}
func Test_Validate_WhenRetryEnabledButMaxTriesIsZero_ValidationFailsRegardlessOfPlan(t *testing.T) {
func Test_Validate_WhenRetryEnabledButMaxTriesIsZero_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.IsRetryIfFailed = true
config.MaxFailedTriesCount = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "max failed tries count must be greater than 0")
}
func Test_Validate_WhenEncryptionIsInvalid_ValidationFailsRegardlessOfPlan(t *testing.T) {
func Test_Validate_WhenEncryptionIsInvalid_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.Encryption = "INVALID"
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "encryption must be NONE or ENCRYPTED")
}
@@ -255,125 +41,16 @@ func Test_Validate_WhenRetentionTimePeriodIsEmpty_ValidationFails(t *testing.T)
config := createValidBackupConfig()
config.RetentionTimePeriod = ""
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "retention time period is required")
}
func Test_Validate_WhenMaxBackupSizeIsNegative_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = -100
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "max backup size must be non-negative")
}
func Test_Validate_WhenMaxTotalSizeIsNegative_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = -1000
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "max backups total size must be non-negative")
}
func Test_Validate_WhenPlanLimitsAreAtBoundary_ValidationWorks(t *testing.T) {
tests := []struct {
name string
configPeriod period.TimePeriod
planPeriod period.TimePeriod
configSize int64
planSize int64
configTotal int64
planTotal int64
shouldSucceed bool
}{
{
name: "all values just under limit",
configPeriod: period.PeriodWeek,
planPeriod: period.PeriodMonth,
configSize: 99,
planSize: 100,
configTotal: 999,
planTotal: 1000,
shouldSucceed: true,
},
{
name: "all values equal to limit",
configPeriod: period.PeriodMonth,
planPeriod: period.PeriodMonth,
configSize: 100,
planSize: 100,
configTotal: 1000,
planTotal: 1000,
shouldSucceed: true,
},
{
name: "period just over limit",
configPeriod: period.Period3Month,
planPeriod: period.PeriodMonth,
configSize: 100,
planSize: 100,
configTotal: 1000,
planTotal: 1000,
shouldSucceed: false,
},
{
name: "size just over limit",
configPeriod: period.PeriodMonth,
planPeriod: period.PeriodMonth,
configSize: 101,
planSize: 100,
configTotal: 1000,
planTotal: 1000,
shouldSucceed: false,
},
{
name: "total size just over limit",
configPeriod: period.PeriodMonth,
planPeriod: period.PeriodMonth,
configSize: 100,
planSize: 100,
configTotal: 1001,
planTotal: 1000,
shouldSucceed: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = tt.configPeriod
config.MaxBackupSizeMB = tt.configSize
config.MaxBackupsTotalSizeMB = tt.configTotal
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = tt.planPeriod
plan.MaxBackupSizeMB = tt.planSize
plan.MaxBackupsTotalSizeMB = tt.planTotal
err := config.Validate(plan)
if tt.shouldSucceed {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
})
}
}
func Test_Validate_WhenPolicyTypeIsCount_RequiresPositiveCount(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = RetentionPolicyTypeCount
config.RetentionCount = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "retention count must be greater than 0")
}
@@ -382,9 +59,7 @@ func Test_Validate_WhenPolicyTypeIsCount_WithPositiveCount_ValidationPasses(t *t
config.RetentionPolicyType = RetentionPolicyTypeCount
config.RetentionCount = 10
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.NoError(t, err)
}
@@ -396,9 +71,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_RequiresAtLeastOneField(t *testing.T) {
config.RetentionGfsMonths = 0
config.RetentionGfsYears = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "at least one GFS retention field must be greater than 0")
}
@@ -407,9 +80,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_WithOnlyHours_ValidationPasses(t *testing
config.RetentionPolicyType = RetentionPolicyTypeGFS
config.RetentionGfsHours = 24
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.NoError(t, err)
}
@@ -418,9 +89,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_WithOnlyDays_ValidationPasses(t *testing.
config.RetentionPolicyType = RetentionPolicyTypeGFS
config.RetentionGfsDays = 7
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.NoError(t, err)
}
@@ -433,9 +102,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_WithAllFields_ValidationPasses(t *testing
config.RetentionGfsMonths = 12
config.RetentionGfsYears = 3
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.NoError(t, err)
}
@@ -443,35 +110,59 @@ func Test_Validate_WhenPolicyTypeIsInvalid_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = "INVALID"
plan := createUnlimitedPlan()
err := config.Validate(plan)
err := config.Validate()
assert.EqualError(t, err, "invalid retention policy type")
}
func Test_Validate_WhenCloudAndEncryptionIsNotEncrypted_ValidationFails(t *testing.T) {
enableCloud(t)
backupConfig := createValidBackupConfig()
backupConfig.Encryption = BackupEncryptionNone
err := backupConfig.Validate()
assert.EqualError(t, err, "encryption is mandatory for cloud storage")
}
func Test_Validate_WhenCloudAndEncryptionIsEncrypted_ValidationPasses(t *testing.T) {
enableCloud(t)
backupConfig := createValidBackupConfig()
backupConfig.Encryption = BackupEncryptionEncrypted
err := backupConfig.Validate()
assert.NoError(t, err)
}
func Test_Validate_WhenNotCloudAndEncryptionIsNotEncrypted_ValidationPasses(t *testing.T) {
backupConfig := createValidBackupConfig()
backupConfig.Encryption = BackupEncryptionNone
err := backupConfig.Validate()
assert.NoError(t, err)
}
func enableCloud(t *testing.T) {
t.Helper()
config.GetEnv().IsCloud = true
t.Cleanup(func() {
config.GetEnv().IsCloud = false
})
}
func createValidBackupConfig() *BackupConfig {
intervalID := uuid.New()
return &BackupConfig{
DatabaseID: uuid.New(),
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodMonth,
BackupIntervalID: intervalID,
BackupInterval: &intervals.Interval{ID: intervalID},
SendNotificationsOn: []BackupNotificationType{},
IsRetryIfFailed: false,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 100,
MaxBackupsTotalSizeMB: 1000,
}
}
func createUnlimitedPlan() *plans.DatabasePlan {
return &plans.DatabasePlan{
DatabaseID: uuid.New(),
MaxBackupSizeMB: 0,
MaxBackupsTotalSizeMB: 0,
MaxStoragePeriod: period.PeriodForever,
return &BackupConfig{
DatabaseID: uuid.New(),
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodMonth,
BackupIntervalID: intervalID,
BackupInterval: &intervals.Interval{ID: intervalID},
SendNotificationsOn: []BackupNotificationType{},
IsRetryIfFailed: false,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
}
}

View File

@@ -62,6 +62,14 @@ func (r *BackupConfigRepository) FindByDatabaseID(databaseID uuid.UUID) (*Backup
GetDb().
Preload("BackupInterval").
Preload("Storage").
Preload("Storage.LocalStorage").
Preload("Storage.S3Storage").
Preload("Storage.GoogleDriveStorage").
Preload("Storage.NASStorage").
Preload("Storage.AzureBlobStorage").
Preload("Storage.FTPStorage").
Preload("Storage.SFTPStorage").
Preload("Storage.RcloneStorage").
Where("database_id = ?", databaseID).
First(&backupConfig).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -81,6 +89,14 @@ func (r *BackupConfigRepository) GetWithEnabledBackups() ([]*BackupConfig, error
GetDb().
Preload("BackupInterval").
Preload("Storage").
Preload("Storage.LocalStorage").
Preload("Storage.S3Storage").
Preload("Storage.GoogleDriveStorage").
Preload("Storage.NASStorage").
Preload("Storage.AzureBlobStorage").
Preload("Storage.FTPStorage").
Preload("Storage.SFTPStorage").
Preload("Storage.RcloneStorage").
Where("is_backups_enabled = ?", true).
Find(&backupConfigs).Error; err != nil {
return nil, err

View File

@@ -8,10 +8,10 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
users_models "databasus-backend/internal/features/users/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/period"
)
type BackupConfigService struct {
@@ -20,7 +20,6 @@ type BackupConfigService struct {
storageService *storages.StorageService
notifierService *notifiers.NotifierService
workspaceService *workspaces_services.WorkspaceService
databasePlanService *plans.DatabasePlanService
dbStorageChangeListener BackupConfigStorageChangeListener
}
@@ -46,12 +45,7 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
user *users_models.User,
backupConfig *BackupConfig,
) (*BackupConfig, error) {
plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
if err != nil {
return nil, err
}
if err := backupConfig.Validate(plan); err != nil {
if err := backupConfig.Validate(); err != nil {
return nil, err
}
@@ -88,12 +82,7 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
func (s *BackupConfigService) SaveBackupConfig(
backupConfig *BackupConfig,
) (*BackupConfig, error) {
plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
if err != nil {
return nil, err
}
if err := backupConfig.Validate(plan); err != nil {
if err := backupConfig.Validate(); err != nil {
return nil, err
}
@@ -131,18 +120,6 @@ func (s *BackupConfigService) GetBackupConfigByDbIdWithAuth(
return s.GetBackupConfigByDbId(databaseID)
}
func (s *BackupConfigService) GetDatabasePlan(
user *users_models.User,
databaseID uuid.UUID,
) (*plans.DatabasePlan, error) {
_, err := s.databaseService.GetDatabase(user, databaseID)
if err != nil {
return nil, err
}
return s.databasePlanService.GetDatabasePlan(databaseID)
}
func (s *BackupConfigService) GetBackupConfigByDbId(
databaseID uuid.UUID,
) (*BackupConfig, error) {
@@ -322,20 +299,13 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace(
func (s *BackupConfigService) initializeDefaultConfig(
databaseID uuid.UUID,
) error {
plan, err := s.databasePlanService.GetDatabasePlan(databaseID)
if err != nil {
return err
}
timeOfDay := "04:00"
_, err = s.backupConfigRepository.Save(&BackupConfig{
DatabaseID: databaseID,
IsBackupsEnabled: false,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: plan.MaxStoragePeriod,
MaxBackupSizeMB: plan.MaxBackupSizeMB,
MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB,
_, err := s.backupConfigRepository.Save(&BackupConfig{
DatabaseID: databaseID,
IsBackupsEnabled: false,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.Period3Month,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,

View File

@@ -0,0 +1,305 @@
package billing
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
users_middleware "databasus-backend/internal/features/users/middleware"
"databasus-backend/internal/util/logger"
)
type BillingController struct {
billingService *BillingService
}
func (c *BillingController) RegisterRoutes(router *gin.RouterGroup) {
billing := router.Group("/billing")
billing.POST("/subscription", c.CreateSubscription)
billing.POST("/subscription/change-storage", c.ChangeSubscriptionStorage)
billing.POST("/subscription/portal/:subscription_id", c.GetPortalSession)
billing.GET("/subscription/events/:subscription_id", c.GetSubscriptionEvents)
billing.GET("/subscription/invoices/:subscription_id", c.GetInvoices)
billing.GET("/subscription/:database_id", c.GetSubscription)
}
// CreateSubscription
// @Summary Create a new subscription
// @Description Create a billing subscription for the specified database with the given storage
// @Tags billing
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body CreateSubscriptionRequest true "Subscription creation data"
// @Success 200 {object} CreateSubscriptionResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /billing/subscription [post]
func (c *BillingController) CreateSubscription(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(401, gin.H{"error": "User not authenticated"})
return
}
var request CreateSubscriptionRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(400, gin.H{"error": "Invalid request"})
return
}
log := logger.GetLogger().With(
"request_id", uuid.New(),
"database_id", request.DatabaseID,
"user_id", user.ID,
)
transactionID, err := c.billingService.CreateSubscription(
log,
user,
request.DatabaseID,
request.StorageGB,
)
if err != nil {
log.Error("Failed to create subscription", "error", err)
ctx.JSON(500, gin.H{"error": "Failed to create subscription"})
return
}
ctx.JSON(200, CreateSubscriptionResponse{PaddleTransactionID: transactionID})
}
// ChangeSubscriptionStorage
// @Summary Change subscription storage
// @Description Update the storage allocation for an existing subscription
// @Tags billing
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body ChangeStorageRequest true "New storage configuration"
// @Success 200 {object} ChangeStorageResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /billing/subscription/change-storage [post]
func (c *BillingController) ChangeSubscriptionStorage(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(401, gin.H{"error": "User not authenticated"})
return
}
var request ChangeStorageRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(400, gin.H{"error": "Invalid request"})
return
}
log := logger.GetLogger().With(
"request_id", uuid.New(),
"database_id", request.DatabaseID,
"user_id", user.ID,
)
result, err := c.billingService.ChangeSubscriptionStorage(log, user, request.DatabaseID, request.StorageGB)
if err != nil {
log.Error("Failed to change subscription storage", "error", err)
ctx.JSON(500, gin.H{"error": "Failed to change subscription storage"})
return
}
ctx.JSON(200, ChangeStorageResponse{
ApplyMode: result.ApplyMode,
CurrentGB: result.CurrentGB,
PendingGB: result.PendingGB,
})
}
// GetPortalSession
// @Summary Get billing portal session
// @Description Generate a portal session URL for managing the subscription
// @Tags billing
// @Produce json
// @Security BearerAuth
// @Param subscription_id path string true "Subscription ID"
// @Success 200 {object} GetPortalSessionResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /billing/subscription/portal/{subscription_id} [post]
func (c *BillingController) GetPortalSession(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(401, gin.H{"error": "User not authenticated"})
return
}
subscriptionID := ctx.Param("subscription_id")
if subscriptionID == "" {
ctx.JSON(400, gin.H{"error": "Subscription ID is required"})
return
}
log := logger.GetLogger().With(
"request_id", uuid.New(),
"subscription_id", subscriptionID,
"user_id", user.ID,
)
url, err := c.billingService.GetPortalURL(log, user, uuid.MustParse(subscriptionID))
if err != nil {
log.Error("Failed to get portal session", "error", err)
ctx.JSON(500, gin.H{"error": "Failed to get portal session"})
return
}
ctx.JSON(200, GetPortalSessionResponse{PortalURL: url})
}
// GetSubscriptionEvents
// @Summary Get subscription events
// @Description Retrieve the event history for a subscription
// @Tags billing
// @Produce json
// @Security BearerAuth
// @Param subscription_id path string true "Subscription ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Success 200 {object} GetSubscriptionEventsResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /billing/subscription/events/{subscription_id} [get]
func (c *BillingController) GetSubscriptionEvents(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(401, gin.H{"error": "User not authenticated"})
return
}
subscriptionID, err := uuid.Parse(ctx.Param("subscription_id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid subscription ID"})
return
}
var request PaginatedRequest
if err := ctx.ShouldBindQuery(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
log := logger.GetLogger().With(
"request_id", uuid.New(),
"subscription_id", subscriptionID,
"user_id", user.ID,
)
response, err := c.billingService.GetSubscriptionEvents(log, user, subscriptionID, request.Limit, request.Offset)
if err != nil {
log.Error("Failed to get subscription events", "error", err)
ctx.JSON(500, gin.H{"error": "Failed to get subscription events"})
return
}
ctx.JSON(200, response)
}
// GetInvoices
// @Summary Get subscription invoices
// @Description Retrieve all invoices for a subscription
// @Tags billing
// @Produce json
// @Security BearerAuth
// @Param subscription_id path string true "Subscription ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Success 200 {object} GetInvoicesResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /billing/subscription/invoices/{subscription_id} [get]
func (c *BillingController) GetInvoices(ctx *gin.Context) {
	user, ok := users_middleware.GetUserFromContext(ctx)
	if !ok {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}
	subscriptionID, err := uuid.Parse(ctx.Param("subscription_id"))
	if err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid subscription ID"})
		return
	}
	var request PaginatedRequest
	if err := ctx.ShouldBindQuery(&request); err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
		return
	}
	// request_id correlates all log lines emitted for this HTTP request.
	log := logger.GetLogger().With(
		"request_id", uuid.New(),
		"subscription_id", subscriptionID,
		"user_id", user.ID,
	)
	// Access control for the subscription is delegated to the billing service.
	response, err := c.billingService.GetSubscriptionInvoices(log, user, subscriptionID, request.Limit, request.Offset)
	if err != nil {
		log.Error("Failed to get invoices", "error", err)
		ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get invoices"})
		return
	}
	ctx.JSON(http.StatusOK, response)
}
// GetSubscription
// @Summary Get subscription by database
// @Description Retrieve the subscription associated with a specific database
// @Tags billing
// @Produce json
// @Security BearerAuth
// @Param database_id path string true "Database ID"
// @Success 200 {object} billing_models.Subscription
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 404 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /billing/subscription/{database_id} [get]
func (c *BillingController) GetSubscription(ctx *gin.Context) {
	user, ok := users_middleware.GetUserFromContext(ctx)
	if !ok {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}
	databaseID, err := uuid.Parse(ctx.Param("database_id"))
	if err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid database ID"})
		return
	}
	// request_id correlates all log lines emitted for this HTTP request.
	log := logger.GetLogger().With(
		"request_id", uuid.New(),
		"database_id", databaseID,
		"user_id", user.ID,
	)
	subscription, err := c.billingService.GetSubscriptionByDatabaseID(log, user, databaseID)
	if err != nil {
		// A missing subscription is an expected case, not a server failure.
		if errors.Is(err, ErrSubscriptionNotFound) {
			ctx.JSON(http.StatusNotFound, gin.H{"error": "Subscription not found"})
			return
		}
		log.Error("failed to get subscription", "error", err)
		ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get subscription"})
		return
	}
	ctx.JSON(http.StatusOK, subscription)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,35 @@
package billing
import (
"sync"
"sync/atomic"
billing_repositories "databasus-backend/internal/features/billing/repositories"
"databasus-backend/internal/features/databases"
workspaces_services "databasus-backend/internal/features/workspaces/services"
)
// Package-level singletons wiring the billing service and its controller.
// NOTE: the BillingService literal uses positional fields, so the order here
// must match the field order of the BillingService struct definition.
var (
	billingService = &BillingService{
		&billing_repositories.SubscriptionRepository{},
		&billing_repositories.SubscriptionEventRepository{},
		&billing_repositories.InvoiceRepository{},
		nil, // billing provider will be set later to avoid circular dependency
		workspaces_services.GetWorkspaceService(),
		*databases.GetDatabaseService(),
		atomic.Bool{},
	}
	billingController = &BillingController{billingService}
)
// GetBillingService returns the process-wide billing service singleton.
func GetBillingService() *BillingService {
	return billingService
}
// GetBillingController returns the process-wide billing controller singleton.
func GetBillingController() *BillingController {
	return billingController
}
// SetupDependencies registers the billing service as a database-creation
// listener. Wrapped in sync.OnceFunc so repeated calls register only once.
var SetupDependencies = sync.OnceFunc(func() {
	databases.GetDatabaseService().AddDbCreationListener(billingService)
})

View File

@@ -0,0 +1,67 @@
package billing
import (
"github.com/google/uuid"
billing_models "databasus-backend/internal/features/billing/models"
)
// CreateSubscriptionRequest is the payload for creating a new subscription
// for a database with a chosen amount of backup storage.
type CreateSubscriptionRequest struct {
	DatabaseID uuid.UUID `json:"databaseId" validate:"required"`
	StorageGB  int       `json:"storageGb" validate:"required,min=1"`
}

// CreateSubscriptionResponse carries the Paddle transaction ID that the
// client uses to complete checkout.
type CreateSubscriptionResponse struct {
	PaddleTransactionID string `json:"paddleTransactionId"`
}

// ChangeStorageApplyMode tells the client when a storage change takes effect.
type ChangeStorageApplyMode string

const (
	ChangeStorageApplyImmediate ChangeStorageApplyMode = "immediate"
	ChangeStorageApplyNextCycle ChangeStorageApplyMode = "next_cycle"
)

// ChangeStorageRequest is the payload for changing a subscription's storage.
type ChangeStorageRequest struct {
	DatabaseID uuid.UUID `json:"databaseId" validate:"required"`
	StorageGB  int       `json:"storageGb" validate:"required,min=1"`
}

// ChangeStorageResponse reports how the change was applied; PendingGB is set
// only when the change is deferred to the next billing cycle.
type ChangeStorageResponse struct {
	ApplyMode ChangeStorageApplyMode `json:"applyMode"`
	CurrentGB int                    `json:"currentGb"`
	PendingGB *int                   `json:"pendingGb,omitempty"`
}

// PortalResponse carries a provider-hosted portal URL.
type PortalResponse struct {
	URL string `json:"url"`
}

// ChangeStorageResult is the service-layer counterpart of ChangeStorageResponse.
type ChangeStorageResult struct {
	ApplyMode ChangeStorageApplyMode
	CurrentGB int
	PendingGB *int
}

// GetPortalSessionResponse carries the billing portal URL for the client.
type GetPortalSessionResponse struct {
	PortalURL string `json:"url"`
}

// PaginatedRequest holds limit/offset query parameters for list endpoints.
// NOTE(review): values are not validated here; negative or zero values are
// passed through to the service layer as-is.
type PaginatedRequest struct {
	Limit  int `form:"limit" json:"limit"`
	Offset int `form:"offset" json:"offset"`
}

// GetSubscriptionEventsResponse is a paginated page of subscription events.
type GetSubscriptionEventsResponse struct {
	Events []*billing_models.SubscriptionEvent `json:"events"`
	Total  int64                               `json:"total"`
	Limit  int                                 `json:"limit"`
	Offset int                                 `json:"offset"`
}

// GetInvoicesResponse is a paginated page of subscription invoices.
type GetInvoicesResponse struct {
	Invoices []*billing_models.Invoice `json:"invoices"`
	Total    int64                     `json:"total"`
	Limit    int                       `json:"limit"`
	Offset   int                       `json:"offset"`
}

View File

@@ -0,0 +1,15 @@
package billing
import "errors"
// Sentinel errors returned by the billing service; callers compare with
// errors.Is to map them to HTTP status codes.
var (
	ErrInvalidStorage       = errors.New("storage must be between 20 and 10000 GB")
	ErrAlreadySubscribed    = errors.New("database already has an active subscription")
	ErrExceedsUsage         = errors.New("cannot downgrade below current storage usage")
	ErrNoChange             = errors.New("requested storage is the same as current")
	ErrDuplicate            = errors.New("duplicate event already processed")
	ErrProviderUnavailable  = errors.New("payment provider unavailable")
	ErrNoActiveSubscription = errors.New("no active subscription for this database")
	ErrAccessDenied         = errors.New("user does not have access to this database")
	ErrSubscriptionNotFound = errors.New("subscription not found")
)

View File

@@ -0,0 +1,24 @@
package billing_models
import (
"time"
"github.com/google/uuid"
)
// Invoice is one billing-provider invoice for a subscription period,
// persisted in the "invoices" table.
type Invoice struct {
	ID                uuid.UUID     `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
	SubscriptionID    uuid.UUID     `json:"subscriptionId" gorm:"column:subscription_id;type:uuid;not null"`
	// ProviderInvoiceID is the invoice/transaction ID on the payment provider's side.
	ProviderInvoiceID string        `json:"providerInvoiceId" gorm:"column:provider_invoice_id;type:text;not null"`
	AmountCents       int64         `json:"amountCents" gorm:"column:amount_cents;type:bigint;not null"`
	StorageGB         int           `json:"storageGb" gorm:"column:storage_gb;type:int;not null"`
	PeriodStart       time.Time     `json:"periodStart" gorm:"column:period_start;type:timestamptz;not null"`
	PeriodEnd         time.Time     `json:"periodEnd" gorm:"column:period_end;type:timestamptz;not null"`
	Status            InvoiceStatus `json:"status" gorm:"column:status;type:text;not null"`
	// PaidAt is nil while the invoice is unpaid.
	PaidAt            *time.Time    `json:"paidAt,omitzero" gorm:"column:paid_at;type:timestamptz"`
	CreatedAt         time.Time     `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"`
}

// TableName maps the model to its database table for GORM.
func (Invoice) TableName() string {
	return "invoices"
}

View File

@@ -0,0 +1,11 @@
package billing_models
// InvoiceStatus is the lifecycle state of an invoice.
type InvoiceStatus string

const (
	InvoiceStatusPending  InvoiceStatus = "pending"
	InvoiceStatusPaid     InvoiceStatus = "paid"
	InvoiceStatusFailed   InvoiceStatus = "failed"
	InvoiceStatusRefunded InvoiceStatus = "refunded"
	InvoiceStatusDisputed InvoiceStatus = "disputed"
)

View File

@@ -0,0 +1,72 @@
package billing_models
import (
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
)
// Subscription is one database's billing subscription, persisted in the
// "subscriptions" table. Provider* fields are nil until a payment provider
// subscription is attached.
type Subscription struct {
	ID                            uuid.UUID          `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
	DatabaseID                    uuid.UUID          `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
	Status                        SubscriptionStatus `json:"status" gorm:"column:status;type:text;not null"`
	StorageGB                     int                `json:"storageGb" gorm:"column:storage_gb;type:int;not null"`
	// PendingStorageGB is set when a downgrade is scheduled for the next cycle.
	PendingStorageGB              *int               `json:"pendingStorageGb,omitempty" gorm:"column:pending_storage_gb;type:int"`
	CurrentPeriodStart            time.Time          `json:"currentPeriodStart" gorm:"column:current_period_start;type:timestamptz;not null"`
	CurrentPeriodEnd              time.Time          `json:"currentPeriodEnd" gorm:"column:current_period_end;type:timestamptz;not null"`
	CanceledAt                    *time.Time         `json:"canceledAt,omitzero" gorm:"column:canceled_at;type:timestamptz"`
	// DataRetentionGracePeriodUntil marks how long backups are retained after cancellation.
	DataRetentionGracePeriodUntil *time.Time         `json:"dataRetentionGracePeriodUntil,omitzero" gorm:"column:data_retention_grace_period_until;type:timestamptz"`
	ProviderName                  *string            `json:"providerName,omitempty" gorm:"column:provider_name;type:text"`
	ProviderSubID                 *string            `json:"providerSubId,omitempty" gorm:"column:provider_sub_id;type:text"`
	ProviderCustomerID            *string            `json:"providerCustomerId,omitempty" gorm:"column:provider_customer_id;type:text"`
	CreatedAt                     time.Time          `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"`
	UpdatedAt                     time.Time          `json:"updatedAt" gorm:"column:updated_at;type:timestamptz;not null"`
}

// TableName maps the model to its database table for GORM.
func (Subscription) TableName() string {
	return "subscriptions"
}
// PriceCents returns the per-cycle price in cents: storage size multiplied
// by the configured per-GB price.
func (s *Subscription) PriceCents() int64 {
	return int64(s.StorageGB) * config.GetEnv().PricePerGBCents
}
// CanCreateNewBackups - whether it is allowed to create new backups
// by scheduler or for user manually. Clarification: in grace period
// user can download, delete and restore backups, but cannot create new ones.
func (s *Subscription) CanCreateNewBackups() bool {
	if s.Status == StatusActive || s.Status == StatusPastDue {
		// Paid (or still-retrying) subscriptions always allow new backups.
		return true
	}
	if s.Status == StatusTrial || s.Status == StatusCanceled {
		// Trial/canceled subscriptions keep working until the period ends.
		return time.Now().Before(s.CurrentPeriodEnd)
	}
	if s.Status == StatusExpired {
		return false
	}
	panic("unknown subscription status")
}
// GetBackupsStorageGB returns how many GB of backup storage the subscription
// currently grants; 0 means no storage is granted.
func (s *Subscription) GetBackupsStorageGB() int {
	switch s.Status {
	case StatusActive, StatusPastDue, StatusCanceled:
		return s.StorageGB
	case StatusTrial:
		// Trial storage is only granted while the trial period is running.
		if time.Now().Before(s.CurrentPeriodEnd) {
			return s.StorageGB
		}
		return 0
	case StatusExpired:
		return 0
	}
	panic("unknown subscription status")
}

View File

@@ -0,0 +1,25 @@
package billing_models
import (
"time"
"github.com/google/uuid"
)
// SubscriptionEvent is an audit-log entry recording a state or storage
// transition of a subscription, persisted in "subscription_events".
// Old*/New* fields are nil when the event does not change that dimension.
type SubscriptionEvent struct {
	ID             uuid.UUID             `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
	SubscriptionID uuid.UUID             `json:"subscriptionId" gorm:"column:subscription_id;type:uuid;not null"`
	// ProviderEventID is nil for internally-generated events.
	ProviderEventID *string              `json:"providerEventId,omitempty" gorm:"column:provider_event_id;type:text"`
	Type           SubscriptionEventType `json:"type" gorm:"column:type;type:text;not null"`
	OldStorageGB   *int                  `json:"oldStorageGb,omitempty" gorm:"column:old_storage_gb;type:int"`
	NewStorageGB   *int                  `json:"newStorageGb,omitempty" gorm:"column:new_storage_gb;type:int"`
	OldStatus      *SubscriptionStatus   `json:"oldStatus,omitempty" gorm:"column:old_status;type:text"`
	NewStatus      *SubscriptionStatus   `json:"newStatus,omitempty" gorm:"column:new_status;type:text"`
	CreatedAt      time.Time             `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"`
}

// TableName maps the model to its database table for GORM.
func (SubscriptionEvent) TableName() string {
	return "subscription_events"
}

View File

@@ -0,0 +1,17 @@
package billing_models
// SubscriptionEventType enumerates the kinds of subscription audit events.
type SubscriptionEventType string

const (
	EventCreated                SubscriptionEventType = "subscription.created"
	EventUpgraded               SubscriptionEventType = "subscription.upgraded"
	EventDowngraded             SubscriptionEventType = "subscription.downgraded"
	EventNewBillingCycleStarted SubscriptionEventType = "subscription.new_billing_cycle_started"
	EventCanceled               SubscriptionEventType = "subscription.canceled"
	EventReactivated            SubscriptionEventType = "subscription.reactivated"
	EventExpired                SubscriptionEventType = "subscription.expired"
	EventPastDue                SubscriptionEventType = "subscription.past_due"
	EventRecoveredFromPastDue   SubscriptionEventType = "subscription.recovered_from_past_due"
	EventRefund                 SubscriptionEventType = "payment.refund"
	EventDispute                SubscriptionEventType = "payment.dispute"
)

View File

@@ -0,0 +1,11 @@
package billing_models
// SubscriptionStatus is the lifecycle state of a subscription.
type SubscriptionStatus string

const (
	StatusTrial    SubscriptionStatus = "trial"    // trial period (~24h after DB creation)
	StatusActive   SubscriptionStatus = "active"   // paid, everything works
	StatusPastDue  SubscriptionStatus = "past_due" // payment failed, trying to charge again, but everything still works
	StatusCanceled SubscriptionStatus = "canceled" // subscription canceled by user or after past_due (grace period is active)
	StatusExpired  SubscriptionStatus = "expired"  // grace period ended, data marked for deletion, can come from canceled and trial
)

View File

@@ -0,0 +1,22 @@
package billing_models
import (
"time"
"github.com/google/uuid"
)
// WebhookEvent is the provider-agnostic, normalized form of an incoming
// payment-provider webhook. Not every field is populated for every event
// type; pointer fields are nil when the payload did not carry them.
type WebhookEvent struct {
	// RequestID is the internal ID assigned when the webhook was received.
	RequestID              uuid.UUID
	ProviderEventID        string
	// DatabaseID is set only for events carrying it (e.g. subscription.created).
	DatabaseID             *uuid.UUID
	Type                   WebhookEventType
	ProviderSubscriptionID string
	ProviderCustomerID     string
	ProviderInvoiceID      string
	QuantityGB             int
	Status                 SubscriptionStatus
	PeriodStart            *time.Time
	PeriodEnd              *time.Time
	AmountCents            int64
}

View File

@@ -0,0 +1,13 @@
package billing_models
// WebhookEventType enumerates normalized webhook event kinds. Provider-
// specific names are mapped onto these (e.g. Paddle's transaction.completed
// becomes payment.succeeded).
type WebhookEventType string

const (
	WHEventSubscriptionCreated        WebhookEventType = "subscription.created"
	WHEventSubscriptionUpdated        WebhookEventType = "subscription.updated"
	WHEventSubscriptionCanceled       WebhookEventType = "subscription.canceled"
	WHEventSubscriptionPastDue        WebhookEventType = "subscription.past_due"
	WHEventSubscriptionReactivated    WebhookEventType = "subscription.reactivated"
	WHEventPaymentSucceeded           WebhookEventType = "payment.succeeded"
	WHEventSubscriptionDisputeCreated WebhookEventType = "dispute.created"
)

View File

@@ -0,0 +1,5 @@
**Paddle hints:**
- **max_quantity on price:** Paddle limits `quantity` on a price to 100 by default. You need to explicitly set the range (`quantity: {minimum: 20, maximum: 10000}`) when creating a price via API or dashboard. Otherwise requests with quantity > 100 will return an error.
- **Full items list on update:** Unlike Stripe, Paddle requires sending **all** subscription items in `PATCH /subscriptions/{id}`, not just the changed ones. `proration_billing_mode` is also required. Without this you can accidentally remove a line item or get a 400.
- **Webhook events mapping:** Paddle uses `transaction.completed` instead of `payment.succeeded`, `transaction.payment_failed` instead of `payment.failed`, `adjustment.created` instead of `dispute.created`.

View File

@@ -0,0 +1,83 @@
package billing_paddle
import (
"encoding/json"
"errors"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
billing_webhooks "databasus-backend/internal/features/billing/webhooks"
"databasus-backend/internal/util/logger"
)
// PaddleBillingController exposes the public (unauthenticated) HTTP endpoint
// that receives Paddle webhooks.
type PaddleBillingController struct {
	paddleBillingService *PaddleBillingService
}
// RegisterPublicRoutes mounts the webhook endpoint on the public router;
// authentication is done via Paddle's webhook signature, not a bearer token.
func (c *PaddleBillingController) RegisterPublicRoutes(router *gin.RouterGroup) {
	router.POST("/billing/paddle/webhook", c.HandlePaddleWebhook)
}
// HandlePaddleWebhook
// @Summary Handle Paddle webhook
// @Description Process incoming webhook events from Paddle payment provider
// @Tags billing
// @Accept json
// @Produce json
// @Param Paddle-Signature header string true "Paddle webhook signature"
// @Success 200
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500
// @Router /billing/paddle/webhook [post]
func (c *PaddleBillingController) HandlePaddleWebhook(ctx *gin.Context) {
	requestID := uuid.New()
	log := logger.GetLogger().With("request_id", requestID)

	// Cap the body at 1 MiB so an oversized payload cannot exhaust memory.
	payload, readErr := io.ReadAll(io.LimitReader(ctx.Request.Body, 1<<20))
	if readErr != nil {
		log.Error("failed to read webhook request body", "error", readErr)
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Failed to read request body"})
		return
	}

	// Flatten the headers into a plain map for signature verification.
	headerMap := make(map[string]string, len(ctx.Request.Header))
	for name := range ctx.Request.Header {
		headerMap[name] = ctx.GetHeader(name)
	}

	// Verify the signature before touching the payload contents.
	if verifyErr := c.paddleBillingService.VerifyWebhookSignature(payload, headerMap); verifyErr != nil {
		log.Warn("paddle webhook signature verification failed", "error", verifyErr)
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid webhook signature"})
		return
	}

	var dto PaddleWebhookDTO
	if unmarshalErr := json.Unmarshal(payload, &dto); unmarshalErr != nil {
		log.Error("failed to unmarshal webhook payload", "error", unmarshalErr)
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid webhook payload"})
		return
	}
	log = log.With(
		"provider_event_id", dto.EventID,
		"event_type", dto.EventType,
	)

	processErr := c.paddleBillingService.ProcessWebhookEvent(log, requestID, dto, payload)
	switch {
	case processErr == nil:
		ctx.Status(http.StatusOK)
	case errors.Is(processErr, billing_webhooks.ErrDuplicateWebhook):
		// Acknowledge duplicates so Paddle stops retrying them.
		log.Info("duplicate webhook event, returning 200 to not force retry")
		ctx.Status(http.StatusOK)
	default:
		log.Error("Failed to process paddle webhook", "error", processErr)
		ctx.Status(http.StatusInternalServerError)
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,72 @@
package billing_paddle
import (
"sync"
"github.com/PaddleHQ/paddle-go-sdk"
"databasus-backend/internal/config"
"databasus-backend/internal/features/billing"
billing_webhooks "databasus-backend/internal/features/billing/webhooks"
)
var (
paddleBillingService *PaddleBillingService
paddleBillingController *PaddleBillingController
)
// initPaddle lazily builds the Paddle SDK client, service and controller.
// Sandbox vs production only changes the SDK constructor; the rest of the
// wiring is identical, so it is constructed once below.
var initPaddle = sync.OnceFunc(func() {
	env := config.GetEnv()
	newSDK := paddle.New
	if env.IsPaddleSandbox {
		newSDK = paddle.NewSandbox
	}
	paddleClient, err := newSDK(env.PaddleApiKey)
	if err != nil {
		// NOTE(review): as before, the construction error is swallowed and
		// the service/controller stay nil — callers must handle nil.
		return
	}
	paddleBillingService = &PaddleBillingService{
		paddleClient,
		paddle.NewWebhookVerifier(env.PaddleWebhookSecret),
		env.PaddlePriceID,
		billing_webhooks.WebhookRepository{},
		billing.GetBillingService(),
	}
	paddleBillingController = &PaddleBillingController{paddleBillingService}
})
// GetPaddleBillingService returns the Paddle billing service singleton.
// Returns nil outside cloud mode, or if Paddle client construction failed
// inside initPaddle.
func GetPaddleBillingService() *PaddleBillingService {
	if !config.GetEnv().IsCloud {
		return nil
	}
	initPaddle()
	return paddleBillingService
}
// GetPaddleBillingController returns the Paddle webhook controller singleton.
// Returns nil outside cloud mode, or if Paddle client construction failed.
func GetPaddleBillingController() *PaddleBillingController {
	if !config.GetEnv().IsCloud {
		return nil
	}
	// Ensure service + controller are initialized
	GetPaddleBillingService()
	return paddleBillingController
}
// SetupDependencies injects the Paddle provider into the billing service,
// breaking the circular dependency between the two packages.
func SetupDependencies() {
	billing.GetBillingService().SetBillingProvider(GetPaddleBillingService())
}

View File

@@ -0,0 +1,9 @@
package billing_paddle
import "encoding/json"
// PaddleWebhookDTO is the outer envelope of every Paddle webhook payload.
// Data is kept raw and decoded later per event type.
type PaddleWebhookDTO struct {
	EventID   string          `json:"event_id"`
	EventType string          `json:"event_type"`
	// Explicit tag for consistency with the other fields; previously the
	// field relied on Go's case-insensitive "data" -> Data matching.
	Data      json.RawMessage `json:"data"`
}

View File

@@ -0,0 +1,50 @@
package billing_paddle
import "time"
// Payload parameter structs — presumably used by tests to fabricate Paddle
// webhook payloads (TODO confirm against their callers; they are not
// referenced in this file).

// TestSubscriptionCreatedPayload describes a fabricated subscription.created event.
type TestSubscriptionCreatedPayload struct {
	EventID     string
	SubID       string
	CustomerID  string
	DatabaseID  string
	QuantityGB  int
	PeriodStart time.Time
	PeriodEnd   time.Time
}

// TestSubscriptionUpdatedPayload describes a fabricated subscription.updated event.
type TestSubscriptionUpdatedPayload struct {
	EventID               string
	SubID                 string
	CustomerID            string
	QuantityGB            int
	PeriodStart           time.Time
	PeriodEnd             time.Time
	HasScheduledChange    bool
	ScheduledChangeAction string
}

// TestSubscriptionCanceledPayload describes a fabricated subscription.canceled event.
type TestSubscriptionCanceledPayload struct {
	EventID    string
	SubID      string
	CustomerID string
}

// TestTransactionCompletedPayload describes a fabricated transaction.completed event.
type TestTransactionCompletedPayload struct {
	EventID     string
	TxnID       string
	SubID       string
	CustomerID  string
	TotalCents  int64
	QuantityGB  int
	PeriodStart time.Time
	PeriodEnd   time.Time
}

// TestSubscriptionPastDuePayload describes a fabricated past-due event.
type TestSubscriptionPastDuePayload struct {
	EventID     string
	SubID       string
	CustomerID  string
	QuantityGB  int
	PeriodStart time.Time
	PeriodEnd   time.Time
}

View File

@@ -0,0 +1,638 @@
package billing_paddle
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strconv"
"time"
"github.com/PaddleHQ/paddle-go-sdk"
"github.com/google/uuid"
"databasus-backend/internal/features/billing"
billing_models "databasus-backend/internal/features/billing/models"
billing_provider "databasus-backend/internal/features/billing/provider"
billing_webhooks "databasus-backend/internal/features/billing/webhooks"
)
// PaddleBillingService implements the billing-provider integration on top of
// the Paddle Go SDK: checkout, quantity changes, portal sessions and webhook
// verification/processing.
type PaddleBillingService struct {
	client *paddle.SDK
	// NOTE(review): field name says "verified" but it holds the verifier;
	// renaming would touch every method using it, so only flagged here.
	webhookVerified   *paddle.WebhookVerifier
	priceID           string
	webhookRepository billing_webhooks.WebhookRepository
	billingService    *billing.BillingService
}
// GetProviderName identifies this implementation as the Paddle provider.
func (s *PaddleBillingService) GetProviderName() billing_provider.ProviderName {
	return billing_provider.ProviderPaddle
}
// CreateCheckoutSession opens a Paddle transaction for the requested storage
// amount and returns the transaction ID used to launch checkout.
func (s *PaddleBillingService) CreateCheckoutSession(
	logger *slog.Logger,
	request billing_provider.CheckoutRequest,
) (string, error) {
	logger = logger.With("database_id", request.DatabaseID)
	logger.Debug(fmt.Sprintf("paddle: creating checkout session for %d GB", request.StorageGB))

	// Single catalog line item: quantity encodes the number of GB.
	lineItem := paddle.NewCreateTransactionItemsCatalogItem(&paddle.CatalogItem{
		PriceID:  s.priceID,
		Quantity: request.StorageGB,
	})
	// database_id travels as custom data so the webhook can link the
	// resulting subscription back to our database record.
	transaction, err := s.client.CreateTransaction(context.Background(), &paddle.CreateTransactionRequest{
		Items:      []paddle.CreateTransactionItems{*lineItem},
		CustomData: paddle.CustomData{"database_id": request.DatabaseID.String()},
		Checkout: &paddle.TransactionCheckout{
			URL: &request.SuccessURL,
		},
	})
	if err != nil {
		logger.Error("paddle: failed to create transaction", "error", err)
		return "", err
	}
	return transaction.ID, nil
}
// UpgradeQuantityWithSurcharge immediately raises the subscription's storage
// quantity, prorating the difference onto the current billing period.
// No-ops when the subscription is already at the requested quantity.
func (s *PaddleBillingService) UpgradeQuantityWithSurcharge(
	logger *slog.Logger,
	providerSubscriptionID string,
	quantityGB int,
) error {
	logger = logger.With("provider_subscription_id", providerSubscriptionID)
	logger.Debug(fmt.Sprintf("paddle: applying upgrade: new storage %d GB", quantityGB))
	// important: paddle requires to send all items
	// in the subscription when updating, not just the changed one
	subscription, err := s.client.GetSubscription(context.Background(), &paddle.GetSubscriptionRequest{
		SubscriptionID: providerSubscriptionID,
	})
	if err != nil {
		logger.Error("paddle: failed to get subscription", "error", err)
		return err
	}
	// Guard: Items[0] below would panic on an item-less subscription.
	if len(subscription.Items) == 0 {
		logger.Error("paddle: subscription has no items")
		return fmt.Errorf("subscription %s has no items", providerSubscriptionID)
	}
	currentQuantity := subscription.Items[0].Quantity
	if currentQuantity == quantityGB {
		logger.Info("paddle: subscription already at requested quantity, skipping upgrade",
			"current_quantity_gb", currentQuantity,
			"requested_quantity_gb", quantityGB,
		)
		return nil
	}
	priceID := subscription.Items[0].Price.ID
	_, err = s.client.UpdateSubscription(context.Background(), &paddle.UpdateSubscriptionRequest{
		SubscriptionID: providerSubscriptionID,
		Items: paddle.NewPatchField([]paddle.SubscriptionUpdateCatalogItem{
			{
				PriceID:  priceID,
				Quantity: quantityGB,
			},
		}),
		// Charge the prorated difference immediately.
		ProrationBillingMode: paddle.NewPatchField(paddle.ProrationBillingModeProratedImmediately),
	})
	if err != nil {
		logger.Error("paddle: failed to update subscription", "error", err)
		return err
	}
	logger.Debug("paddle: successfully applied upgrade")
	return nil
}
// ScheduleQuantityDowngradeFromNextBillingCycle lowers the subscription's
// storage quantity starting from the next billing period. No-ops when the
// quantity already matches or another change is already scheduled.
func (s *PaddleBillingService) ScheduleQuantityDowngradeFromNextBillingCycle(
	logger *slog.Logger,
	providerSubscriptionID string,
	quantityGB int,
) error {
	logger = logger.With("provider_subscription_id", providerSubscriptionID)
	logger.Debug(fmt.Sprintf("paddle: scheduling downgrade from next billing cycle: new storage %d GB", quantityGB))
	// important: paddle requires to send all items
	// in the subscription when updating, not just the changed one
	subscription, err := s.client.GetSubscription(context.Background(), &paddle.GetSubscriptionRequest{
		SubscriptionID: providerSubscriptionID,
	})
	if err != nil {
		logger.Error("paddle: failed to get subscription", "error", err)
		return err
	}
	// Guard: Items[0] below would panic on an item-less subscription.
	if len(subscription.Items) == 0 {
		logger.Error("paddle: subscription has no items")
		return fmt.Errorf("subscription %s has no items", providerSubscriptionID)
	}
	currentQuantity := subscription.Items[0].Quantity
	if currentQuantity == quantityGB {
		logger.Info("paddle: subscription already at requested quantity, skipping downgrade",
			"current_quantity_gb", currentQuantity,
			"requested_quantity_gb", quantityGB,
		)
		return nil
	}
	if subscription.ScheduledChange != nil {
		logger.Info("paddle: subscription already has a scheduled change, skipping downgrade")
		return nil
	}
	priceID := subscription.Items[0].Price.ID
	// apply downgrade from next billing cycle by setting the proration billing mode to "prorate on next billing period"
	_, err = s.client.UpdateSubscription(context.Background(), &paddle.UpdateSubscriptionRequest{
		SubscriptionID: providerSubscriptionID,
		Items: paddle.NewPatchField([]paddle.SubscriptionUpdateCatalogItem{
			{
				PriceID:  priceID,
				Quantity: quantityGB,
			},
		}),
		ProrationBillingMode: paddle.NewPatchField(paddle.ProrationBillingModeFullNextBillingPeriod),
	})
	if err != nil {
		logger.Error("paddle: failed to update subscription for downgrade", "error", err)
		return fmt.Errorf("failed to update subscription: %w", err)
	}
	logger.Debug("paddle: successfully scheduled downgrade from next billing cycle")
	return nil
}
// GetSubscription fetches the subscription from Paddle and converts it to
// the provider-agnostic representation.
func (s *PaddleBillingService) GetSubscription(
	logger *slog.Logger,
	providerSubscriptionID string,
) (billing_provider.ProviderSubscription, error) {
	logger = logger.With("provider_subscription_id", providerSubscriptionID)
	logger.Debug("paddle: getting subscription details")
	subscription, err := s.client.GetSubscription(context.Background(), &paddle.GetSubscriptionRequest{
		SubscriptionID: providerSubscriptionID,
	})
	if err != nil {
		logger.Error("paddle: failed to get subscription", "error", err)
		return billing_provider.ProviderSubscription{}, err
	}
	// Guard: the debug log below indexes Items[0] and would panic on an
	// item-less subscription.
	if len(subscription.Items) == 0 {
		logger.Error("paddle: subscription has no items")
		return billing_provider.ProviderSubscription{}, fmt.Errorf("subscription %s has no items", providerSubscriptionID)
	}
	logger.Debug(
		fmt.Sprintf(
			"paddle: successfully got subscription details: status=%s, quantity=%d",
			subscription.Status,
			subscription.Items[0].Quantity,
		),
	)
	return s.toProviderSubscription(logger, subscription)
}
// CreatePortalSession returns a Paddle-hosted management URL for the
// customer's subscription. Paddle has no dedicated portal-session API, so it
// reuses the subscription's UpdatePaymentMethod management URL.
// NOTE(review): returnURL is accepted but never used — confirm whether that
// is intentional (Paddle management URLs do not take a return URL).
func (s *PaddleBillingService) CreatePortalSession(
	logger *slog.Logger,
	providerCustomerID, returnURL string,
) (string, error) {
	logger = logger.With("provider_customer_id", providerCustomerID)
	logger.Debug("paddle: creating portal session")
	subscriptions, err := s.client.ListSubscriptions(context.Background(), &paddle.ListSubscriptionsRequest{
		CustomerID: []string{providerCustomerID},
		Status: []string{
			string(paddle.SubscriptionStatusActive),
			string(paddle.SubscriptionStatusPastDue),
		},
	})
	if err != nil {
		logger.Error("paddle: failed to list subscriptions for portal session", "error", err)
		return "", err
	}
	// Advance the iterator once: only the first matching subscription is used.
	res := subscriptions.Next(context.Background())
	if !res.Ok() {
		if res.Err() != nil {
			logger.Error("paddle: failed to iterate subscriptions", "error", res.Err())
			return "", res.Err()
		}
		// Iterator exhausted without error: the customer has no active/past-due subscription.
		logger.Error("paddle: no active subscriptions found for customer")
		return "", fmt.Errorf("no active subscriptions found for customer %s", providerCustomerID)
	}
	subscription := res.Value()
	if subscription.ManagementURLs.UpdatePaymentMethod == nil {
		logger.Error("paddle: subscription has no management URL")
		return "", fmt.Errorf("subscription %s has no management URL", subscription.ID)
	}
	return *subscription.ManagementURLs.UpdatePaymentMethod, nil
}
// VerifyWebhookSignature checks the Paddle webhook signature by rebuilding an
// http.Request from the raw body and headers, since the SDK verifier operates
// on requests. Returns nil only when the signature is valid.
func (s *PaddleBillingService) VerifyWebhookSignature(body []byte, headers map[string]string) error {
	// Previously the construction error was discarded with `_`.
	req, err := http.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader(body))
	if err != nil {
		return fmt.Errorf("failed to build verification request: %w", err)
	}
	for k, v := range headers {
		req.Header.Set(k, v)
	}
	ok, err := s.webhookVerified.Verify(req)
	if err != nil {
		return fmt.Errorf("failed to verify webhook signature: %w", err)
	}
	// Fix: the old code wrapped a nil error here, producing a malformed
	// "%!w(<nil>)" message when verification simply returned ok=false.
	if !ok {
		return errors.New("invalid webhook signature")
	}
	return nil
}
// ProcessWebhookEvent normalizes a raw Paddle webhook, deduplicates it,
// persists an audit record, dispatches it to the billing service, and marks
// the record processed/errored. Returns billing_webhooks.ErrDuplicateWebhook
// when the event was already successfully handled.
func (s *PaddleBillingService) ProcessWebhookEvent(
	logger *slog.Logger,
	requestID uuid.UUID,
	webhookDTO PaddleWebhookDTO,
	rawBody []byte,
) error {
	webhookEvent, err := s.normalizeWebhookEvent(
		logger,
		requestID,
		webhookDTO.EventID,
		webhookDTO.EventType,
		webhookDTO.Data,
	)
	if err != nil {
		// Unsupported event types are recorded as skipped, not treated as failures.
		if errors.Is(err, billing_webhooks.ErrUnsupportedEventType) {
			return s.skipWebhookEvent(logger, requestID, webhookDTO, rawBody)
		}
		logger.Error("paddle: failed to normalize webhook event", "error", err)
		return err
	}
	logArgs := []any{
		"provider_event_id", webhookEvent.ProviderEventID,
		"provider_subscription_id", webhookEvent.ProviderSubscriptionID,
		"provider_customer_id", webhookEvent.ProviderCustomerID,
	}
	if webhookEvent.DatabaseID != nil {
		logArgs = append(logArgs, "database_id", webhookEvent.DatabaseID)
	}
	logger = logger.With(logArgs...)
	// Idempotency: a prior successful record means Paddle is retrying an
	// event we already applied.
	existingRecord, err := s.webhookRepository.FindSuccessfulByProviderEventID(webhookEvent.ProviderEventID)
	if err == nil && existingRecord != nil {
		logger.Info("paddle: webhook already processed successfully, skipping",
			"existing_request_id", existingRecord.RequestID,
		)
		return billing_webhooks.ErrDuplicateWebhook
	}
	// Persist the raw payload before processing so failures stay auditable.
	webhookRecord := &billing_webhooks.WebhookRecord{
		RequestID:       requestID,
		ProviderName:    billing_provider.ProviderPaddle,
		EventType:       string(webhookEvent.Type),
		ProviderEventID: webhookEvent.ProviderEventID,
		RawPayload:      string(rawBody),
	}
	if err := s.webhookRepository.Insert(webhookRecord); err != nil {
		logger.Error("paddle: failed to save webhook record", "error", err)
		return err
	}
	if err := s.processWebhookEvent(logger, webhookEvent); err != nil {
		logger.Error("paddle: failed to process webhook event", "error", err)
		// Best effort: a failed status update is logged but does not mask
		// the processing error returned to the caller.
		if markErr := s.webhookRepository.MarkError(requestID.String(), err.Error()); markErr != nil {
			logger.Error("paddle: failed to mark webhook as errored", "error", markErr)
		}
		return err
	}
	if markErr := s.webhookRepository.MarkProcessed(requestID.String()); markErr != nil {
		logger.Error("paddle: failed to mark webhook as processed", "error", markErr)
	}
	return nil
}
// skipWebhookEvent persists a webhook we deliberately do not handle and
// marks it as skipped, keeping the event auditable without retries.
func (s *PaddleBillingService) skipWebhookEvent(
	logger *slog.Logger,
	requestID uuid.UUID,
	webhookDTO PaddleWebhookDTO,
	rawBody []byte,
) error {
	record := billing_webhooks.WebhookRecord{
		RequestID:       requestID,
		ProviderName:    billing_provider.ProviderPaddle,
		EventType:       webhookDTO.EventType,
		ProviderEventID: webhookDTO.EventID,
		RawPayload:      string(rawBody),
	}
	if insertErr := s.webhookRepository.Insert(&record); insertErr != nil {
		logger.Error("paddle: failed to save skipped webhook record", "error", insertErr)
		return insertErr
	}
	// Best effort: a failed status update is logged but not fatal.
	if markErr := s.webhookRepository.MarkSkipped(requestID.String()); markErr != nil {
		logger.Error("paddle: failed to mark webhook as skipped", "error", markErr)
	}
	return nil
}
// processWebhookEvent routes a normalized webhook event to the matching
// billing-service operation. Events that do not map to an existing local
// subscription (created, dispute) are dispatched before the lookup.
func (s *PaddleBillingService) processWebhookEvent(
	logger *slog.Logger,
	webhookEvent billing_models.WebhookEvent,
) error {
	logger.Debug("processing webhook event")
	// subscription.created - there is no subscription in the database yet
	if webhookEvent.Type == billing_models.WHEventSubscriptionCreated {
		return s.billingService.ActivateSubscription(logger, webhookEvent)
	}
	// dispute - finds subscription via invoice, no provider subscription ID available
	if webhookEvent.Type == billing_models.WHEventSubscriptionDisputeCreated {
		return s.billingService.RecordDispute(logger, webhookEvent)
	}
	// for others - search subscription first
	subscription, err := s.billingService.GetSubscriptionByProviderSubID(logger, webhookEvent.ProviderSubscriptionID)
	if err != nil {
		logger.Error("paddle: failed to find subscription for webhook event", "error", err)
		return err
	}
	logger = logger.With("subscription_id", subscription.ID, "database_id", subscription.DatabaseID)
	logger.Debug(fmt.Sprintf("found subscription in DB with ID: %s", subscription.ID))
	switch webhookEvent.Type {
	case billing_models.WHEventSubscriptionUpdated:
		// An update arriving for a canceled subscription is treated as a reactivation.
		if subscription.Status == billing_models.StatusCanceled {
			return s.billingService.ReactivateSubscription(logger, subscription, webhookEvent)
		}
		return s.billingService.SyncSubscriptionFromProvider(logger, subscription, webhookEvent)
	case billing_models.WHEventSubscriptionCanceled:
		return s.billingService.CancelSubscription(logger, subscription, webhookEvent)
	case billing_models.WHEventPaymentSucceeded:
		return s.billingService.RecordPaymentSuccess(logger, subscription, webhookEvent)
	case billing_models.WHEventSubscriptionPastDue:
		return s.billingService.RecordPaymentFailed(logger, subscription, webhookEvent)
	default:
		// Unknown types are logged but acknowledged so the provider does not retry.
		logger.Error(fmt.Sprintf("unhandled webhook event type: %s", string(webhookEvent.Type)))
		return nil
	}
}
// normalizeWebhookEvent converts a raw Paddle webhook payload into the
// provider-agnostic billing_models.WebhookEvent consumed by the billing
// pipeline. Event types this integration does not handle yield
// billing_webhooks.ErrUnsupportedEventType so callers can record-and-skip.
//
// Fixes over the original:
//   - subscription payloads with an empty Items list no longer panic on
//     Items[0]; an error is returned instead.
//   - a missing database_id in subscription.created custom data now returns
//     an error instead of only logging and falling through.
//   - uuid.Parse replaces uuid.MustParse so a malformed database_id surfaces
//     as an error rather than a panic inside the webhook handler.
func (s *PaddleBillingService) normalizeWebhookEvent(
	logger *slog.Logger,
	requestID uuid.UUID,
	eventID, eventType string,
	data json.RawMessage,
) (billing_models.WebhookEvent, error) {
	webhookEvent := billing_models.WebhookEvent{
		RequestID:       requestID,
		ProviderEventID: eventID,
	}

	switch eventType {
	case "subscription.created":
		webhookEvent.Type = billing_models.WHEventSubscriptionCreated
		var subscription paddle.Subscription
		if err := json.Unmarshal(data, &subscription); err != nil {
			logger.Error("paddle: failed to unmarshal subscription.created webhook data", "error", err)
			return webhookEvent, err
		}
		webhookEvent.ProviderSubscriptionID = subscription.ID
		webhookEvent.ProviderCustomerID = subscription.CustomerID
		// Indexing Items[0] on an empty list would panic and kill the handler.
		if len(subscription.Items) == 0 {
			logger.Error("paddle: subscription.created has no items")
			return webhookEvent, fmt.Errorf("paddle subscription %s has no items", subscription.ID)
		}
		webhookEvent.QuantityGB = subscription.Items[0].Quantity
		status, err := mapPaddleStatus(logger, subscription.Status)
		if err != nil {
			return webhookEvent, err
		}
		webhookEvent.Status = status
		if subscription.CurrentBillingPeriod != nil {
			periodStart := mustParseRFC3339(logger, "period start", subscription.CurrentBillingPeriod.StartsAt)
			periodEnd := mustParseRFC3339(logger, "period end", subscription.CurrentBillingPeriod.EndsAt)
			webhookEvent.PeriodStart = &periodStart
			webhookEvent.PeriodEnd = &periodEnd
		}
		// database_id in custom data links the Paddle subscription back to
		// our database record; the event cannot be attributed without it.
		// (Indexing a nil CustomData map safely yields nil, so the type
		// assertion below also covers the missing-map case.)
		databaseIDStr, isOk := subscription.CustomData["database_id"].(string)
		if !isOk || databaseIDStr == "" {
			logger.Error("paddle: subscription has no database_id in custom data")
			return webhookEvent, fmt.Errorf("missing or invalid database_id in custom data")
		}
		databaseID, err := uuid.Parse(databaseIDStr)
		if err != nil {
			logger.Error("paddle: failed to parse database_id from custom data", "error", err)
			return webhookEvent, fmt.Errorf("invalid database_id in custom data: %w", err)
		}
		webhookEvent.DatabaseID = &databaseID
	case "subscription.updated":
		webhookEvent.Type = billing_models.WHEventSubscriptionUpdated
		var subscription paddle.Subscription
		if err := json.Unmarshal(data, &subscription); err != nil {
			return webhookEvent, err
		}
		webhookEvent.ProviderSubscriptionID = subscription.ID
		webhookEvent.ProviderCustomerID = subscription.CustomerID
		if len(subscription.Items) == 0 {
			logger.Error("paddle: subscription.updated has no items")
			return webhookEvent, fmt.Errorf("paddle subscription %s has no items", subscription.ID)
		}
		webhookEvent.QuantityGB = subscription.Items[0].Quantity
		status, err := mapPaddleStatus(logger, subscription.Status)
		if err != nil {
			return webhookEvent, err
		}
		webhookEvent.Status = status
		if subscription.CurrentBillingPeriod != nil {
			periodStart := mustParseRFC3339(logger, "period start", subscription.CurrentBillingPeriod.StartsAt)
			periodEnd := mustParseRFC3339(logger, "period end", subscription.CurrentBillingPeriod.EndsAt)
			webhookEvent.PeriodStart = &periodStart
			webhookEvent.PeriodEnd = &periodEnd
		}
		// A scheduled cancel on an update is surfaced as a cancellation event.
		if subscription.ScheduledChange != nil &&
			subscription.ScheduledChange.Action == paddle.ScheduledChangeActionCancel {
			webhookEvent.Type = billing_models.WHEventSubscriptionCanceled
		}
	case "subscription.canceled":
		webhookEvent.Type = billing_models.WHEventSubscriptionCanceled
		var subscription paddle.Subscription
		if err := json.Unmarshal(data, &subscription); err != nil {
			return webhookEvent, err
		}
		webhookEvent.ProviderSubscriptionID = subscription.ID
		webhookEvent.ProviderCustomerID = subscription.CustomerID
		status, err := mapPaddleStatus(logger, subscription.Status)
		if err != nil {
			return webhookEvent, err
		}
		webhookEvent.Status = status
	case "transaction.completed":
		webhookEvent.Type = billing_models.WHEventPaymentSucceeded
		var transaction paddle.Transaction
		if err := json.Unmarshal(data, &transaction); err != nil {
			return webhookEvent, err
		}
		webhookEvent.ProviderInvoiceID = transaction.ID
		if len(transaction.Items) > 0 {
			webhookEvent.QuantityGB = transaction.Items[0].Quantity
		}
		if transaction.SubscriptionID != nil {
			webhookEvent.ProviderSubscriptionID = *transaction.SubscriptionID
		}
		if transaction.CustomerID != nil {
			webhookEvent.ProviderCustomerID = *transaction.CustomerID
		}
		// Paddle reports the total as a string; a parse failure is logged
		// but deliberately non-fatal so the payment event is still recorded.
		amountCents, err := strconv.ParseInt(transaction.Details.Totals.Total, 10, 64)
		if err != nil {
			logger.Error("paddle: failed to parse transaction total", "error", err)
		} else {
			webhookEvent.AmountCents = amountCents
		}
		if transaction.BillingPeriod != nil {
			periodStart := mustParseRFC3339(logger, "period start", transaction.BillingPeriod.StartsAt)
			periodEnd := mustParseRFC3339(logger, "period end", transaction.BillingPeriod.EndsAt)
			webhookEvent.PeriodStart = &periodStart
			webhookEvent.PeriodEnd = &periodEnd
		}
	case "subscription.past_due":
		webhookEvent.Type = billing_models.WHEventSubscriptionPastDue
		var subscription paddle.Subscription
		if err := json.Unmarshal(data, &subscription); err != nil {
			return webhookEvent, err
		}
		webhookEvent.ProviderSubscriptionID = subscription.ID
		webhookEvent.ProviderCustomerID = subscription.CustomerID
		if len(subscription.Items) == 0 {
			logger.Error("paddle: subscription.past_due has no items")
			return webhookEvent, fmt.Errorf("paddle subscription %s has no items", subscription.ID)
		}
		webhookEvent.QuantityGB = subscription.Items[0].Quantity
		status, err := mapPaddleStatus(logger, subscription.Status)
		if err != nil {
			return webhookEvent, err
		}
		webhookEvent.Status = status
		if subscription.CurrentBillingPeriod != nil {
			periodStart := mustParseRFC3339(logger, "period start", subscription.CurrentBillingPeriod.StartsAt)
			periodEnd := mustParseRFC3339(logger, "period end", subscription.CurrentBillingPeriod.EndsAt)
			webhookEvent.PeriodStart = &periodStart
			webhookEvent.PeriodEnd = &periodEnd
		}
	case "adjustment.created":
		// Adjustments are mapped to dispute events; only the linked
		// transaction ID is extracted from the payload.
		webhookEvent.Type = billing_models.WHEventSubscriptionDisputeCreated
		var adjustment struct {
			TransactionID string `json:"transaction_id"`
		}
		if err := json.Unmarshal(data, &adjustment); err != nil {
			return webhookEvent, err
		}
		webhookEvent.ProviderInvoiceID = adjustment.TransactionID
	default:
		logger.Debug("unsupported paddle event type, skipping", "event_type", eventType)
		return webhookEvent, billing_webhooks.ErrUnsupportedEventType
	}
	return webhookEvent, nil
}
// toProviderSubscription maps a Paddle subscription onto the
// provider-agnostic billing_provider.ProviderSubscription value. It fails
// when the status is unknown or the subscription has no items, since
// QuantityGB is derived from the first line item.
func (s *PaddleBillingService) toProviderSubscription(
	logger *slog.Logger,
	paddleSubscription *paddle.Subscription,
) (billing_provider.ProviderSubscription, error) {
	status, err := mapPaddleStatus(logger, paddleSubscription.Status)
	if err != nil {
		return billing_provider.ProviderSubscription{}, err
	}
	if len(paddleSubscription.Items) == 0 {
		return billing_provider.ProviderSubscription{}, fmt.Errorf(
			"paddle subscription %s has no items",
			paddleSubscription.ID,
		)
	}
	// Build the result as a plain value: the original allocated it behind a
	// pointer only to dereference it on return, which bought nothing.
	providerSubscription := billing_provider.ProviderSubscription{
		ProviderSubscriptionID: paddleSubscription.ID,
		ProviderCustomerID:     paddleSubscription.CustomerID,
		Status:                 status,
		QuantityGB:             paddleSubscription.Items[0].Quantity,
	}
	if paddleSubscription.CurrentBillingPeriod != nil {
		providerSubscription.PeriodStart = mustParseRFC3339(
			logger,
			"period start",
			paddleSubscription.CurrentBillingPeriod.StartsAt,
		)
		providerSubscription.PeriodEnd = mustParseRFC3339(
			logger,
			"period end",
			paddleSubscription.CurrentBillingPeriod.EndsAt,
		)
	}
	return providerSubscription, nil
}
func mustParseRFC3339(logger *slog.Logger, label, value string) time.Time {
parsed, err := time.Parse(time.RFC3339, value)
if err != nil {
logger.Error(fmt.Sprintf("paddle: failed to parse %s", label), "error", err)
}
return parsed
}
// mapPaddleStatus translates a Paddle subscription status into the internal
// billing status. Unknown statuses are logged and reported as an error.
func mapPaddleStatus(logger *slog.Logger, s paddle.SubscriptionStatus) (billing_models.SubscriptionStatus, error) {
	switch s {
	case paddle.SubscriptionStatusActive:
		return billing_models.StatusActive, nil
	case paddle.SubscriptionStatusTrialing:
		return billing_models.StatusTrial, nil
	case paddle.SubscriptionStatusPastDue:
		return billing_models.StatusPastDue, nil
	case paddle.SubscriptionStatusCanceled, paddle.SubscriptionStatusPaused:
		// Paused is collapsed into canceled — there is no distinct paused
		// state downstream. NOTE(review): confirm this is intentional.
		return billing_models.StatusCanceled, nil
	}
	logger.Error("paddle: unknown subscription status: " + string(s))
	return "", fmt.Errorf("paddle: unknown subscription status: %s", string(s))
}

View File

@@ -0,0 +1,38 @@
package billing_provider
import (
"time"
"github.com/google/uuid"
billing_models "databasus-backend/internal/features/billing/models"
)
// CreateSubscriptionRequest carries the inputs for creating a provider-side
// subscription on behalf of a database.
type CreateSubscriptionRequest struct {
	ProviderCustomerID string
	DatabaseID         uuid.UUID
	// StorageGB is the purchased storage quota in gigabytes.
	StorageGB int
}

// ProviderSubscription is the provider-agnostic view of a subscription as
// reported by a billing provider.
type ProviderSubscription struct {
	ProviderSubscriptionID string
	ProviderCustomerID     string
	Status                 billing_models.SubscriptionStatus
	// QuantityGB mirrors the first line item's quantity on the provider side.
	QuantityGB  int
	PeriodStart time.Time
	PeriodEnd   time.Time
}

// CheckoutRequest describes a checkout session to be created with the
// billing provider.
type CheckoutRequest struct {
	DatabaseID uuid.UUID
	Email      string
	StorageGB  int
	// SuccessURL and CancelURL are where the user is sent after checkout
	// completes or is abandoned — presumably provider redirects; confirm
	// against the provider implementation.
	SuccessURL string
	CancelURL  string
}

// ProviderName identifies a supported billing provider implementation.
type ProviderName string

const (
	// ProviderPaddle is the Paddle billing integration.
	ProviderPaddle ProviderName = "paddle"
)

View File

@@ -0,0 +1,21 @@
package billing_provider
import "log/slog"
// BillingProvider abstracts a third-party billing backend (e.g. Paddle) so
// the billing feature can stay provider-agnostic.
type BillingProvider interface {
	// GetProviderName identifies the concrete provider implementation.
	GetProviderName() ProviderName
	// UpgradeQuantityWithSurcharge changes the subscription quantity (GB)
	// immediately — presumably with a prorated surcharge; confirm against
	// the provider implementation.
	UpgradeQuantityWithSurcharge(logger *slog.Logger, providerSubscriptionID string, quantityGB int) error
	// ScheduleQuantityDowngradeFromNextBillingCycle defers the quantity
	// reduction until the next billing cycle starts.
	ScheduleQuantityDowngradeFromNextBillingCycle(
		logger *slog.Logger,
		providerSubscriptionID string,
		quantityGB int,
	) error
	// GetSubscription fetches the current provider-side subscription state.
	GetSubscription(logger *slog.Logger, providerSubscriptionID string) (ProviderSubscription, error)
	// CreateCheckoutSession returns a URL the user visits to pay.
	CreateCheckoutSession(logger *slog.Logger, req CheckoutRequest) (checkoutURL string, err error)
	// CreatePortalSession returns a URL to the provider's self-service
	// billing portal for the given customer.
	CreatePortalSession(logger *slog.Logger, providerCustomerID, returnURL string) (portalURL string, err error)
}

View File

@@ -0,0 +1,72 @@
package billing_repositories
import (
"errors"
"github.com/google/uuid"
"gorm.io/gorm"
billing_models "databasus-backend/internal/features/billing/models"
"databasus-backend/internal/storage"
)
// InvoiceRepository persists billing invoices.
type InvoiceRepository struct{}

// Save inserts the invoice when it has no ID yet (assigning a fresh UUID),
// otherwise writes the existing row back. A subscription ID is mandatory.
func (r *InvoiceRepository) Save(invoice billing_models.Invoice) error {
	if invoice.SubscriptionID == uuid.Nil {
		return errors.New("subscription id is required")
	}
	db := storage.GetDb()
	if invoice.ID == uuid.Nil {
		invoice.ID = uuid.New()
		return db.Create(&invoice).Error
	}
	// Pass a pointer, as GORM expects for updates and as every other
	// repository in this package does (the original passed the value).
	return db.Save(&invoice).Error
}
// FindByProviderInvID looks an invoice up by the provider's invoice ID,
// returning (nil, nil) when no row matches.
func (r *InvoiceRepository) FindByProviderInvID(providerInvoiceID string) (*billing_models.Invoice, error) {
	var invoice billing_models.Invoice
	err := storage.GetDb().Where("provider_invoice_id = ?", providerInvoiceID).
		First(&invoice).Error
	switch {
	case err == nil:
		return &invoice, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, nil
	default:
		return nil, err
	}
}
// FindByDatabaseID returns a page of invoices belonging to the given
// database, newest first. Paging is controlled by limit and offset.
func (r *InvoiceRepository) FindByDatabaseID(
	databaseID uuid.UUID,
	limit, offset int,
) ([]*billing_models.Invoice, error) {
	var invoices []*billing_models.Invoice
	query := storage.GetDb().
		Joins("JOIN subscriptions ON subscriptions.id = invoices.subscription_id").
		Where("subscriptions.database_id = ?", databaseID).
		Order("invoices.created_at DESC").
		Limit(limit).
		Offset(offset)
	if err := query.Find(&invoices).Error; err != nil {
		return nil, err
	}
	return invoices, nil
}
// CountByDatabaseID returns how many invoices exist for the given database.
func (r *InvoiceRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error) {
	var total int64
	query := storage.GetDb().
		Model(&billing_models.Invoice{}).
		Joins("JOIN subscriptions ON subscriptions.id = invoices.subscription_id").
		Where("subscriptions.database_id = ?", databaseID)
	if err := query.Count(&total).Error; err != nil {
		return 0, err
	}
	return total, nil
}

View File

@@ -0,0 +1,50 @@
package billing_repositories
import (
"errors"
"github.com/google/uuid"
billing_models "databasus-backend/internal/features/billing/models"
"databasus-backend/internal/storage"
)
// SubscriptionEventRepository persists subscription lifecycle events.
type SubscriptionEventRepository struct{}

// Create stores a new subscription event under a freshly generated ID.
// The subscription ID is mandatory.
func (r *SubscriptionEventRepository) Create(event billing_models.SubscriptionEvent) error {
	if event.SubscriptionID == uuid.Nil {
		return errors.New("subscription id is required")
	}
	event.ID = uuid.New()
	db := storage.GetDb()
	return db.Create(&event).Error
}
// FindByDatabaseID returns a page of subscription events for the given
// database, newest first. Paging is controlled by limit and offset.
func (r *SubscriptionEventRepository) FindByDatabaseID(
	databaseID uuid.UUID,
	limit, offset int,
) ([]*billing_models.SubscriptionEvent, error) {
	var events []*billing_models.SubscriptionEvent
	query := storage.GetDb().
		Joins("JOIN subscriptions ON subscriptions.id = subscription_events.subscription_id").
		Where("subscriptions.database_id = ?", databaseID).
		Order("subscription_events.created_at DESC").
		Limit(limit).
		Offset(offset)
	if err := query.Find(&events).Error; err != nil {
		return nil, err
	}
	return events, nil
}
// CountByDatabaseID returns how many subscription events exist for the
// given database.
func (r *SubscriptionEventRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error) {
	var total int64
	query := storage.GetDb().
		Model(&billing_models.SubscriptionEvent{}).
		Joins("JOIN subscriptions ON subscriptions.id = subscription_events.subscription_id").
		Where("subscriptions.database_id = ?", databaseID)
	if err := query.Count(&total).Error; err != nil {
		return 0, err
	}
	return total, nil
}

View File

@@ -0,0 +1,123 @@
package billing_repositories
import (
"errors"
"time"
"github.com/google/uuid"
"gorm.io/gorm"
billing_models "databasus-backend/internal/features/billing/models"
"databasus-backend/internal/storage"
)
// SubscriptionRepository persists billing subscriptions.
type SubscriptionRepository struct{}

// Save creates the subscription when it has no ID yet (assigning a fresh
// UUID), otherwise writes the existing row back.
func (r *SubscriptionRepository) Save(sub billing_models.Subscription) error {
	db := storage.GetDb()
	if sub.ID != uuid.Nil {
		return db.Save(&sub).Error
	}
	sub.ID = uuid.New()
	return db.Create(&sub).Error
}
// FindByID loads a subscription by primary key, returning (nil, nil) when
// no row matches.
func (r *SubscriptionRepository) FindByID(id uuid.UUID) (*billing_models.Subscription, error) {
	var sub billing_models.Subscription
	err := storage.GetDb().Where("id = ?", id).First(&sub).Error
	if err == nil {
		return &sub, nil
	}
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, nil
	}
	return nil, err
}
// FindByDatabaseIDAndStatuses returns every subscription for the given
// database whose status is one of statuses.
//
// Fix: renames the misspelled parameter "stauses" to "statuses". Parameter
// names are not part of the Go call interface, so callers are unaffected.
func (r *SubscriptionRepository) FindByDatabaseIDAndStatuses(
	databaseID uuid.UUID,
	statuses []billing_models.SubscriptionStatus,
) ([]*billing_models.Subscription, error) {
	var subs []*billing_models.Subscription
	if err := storage.GetDb().Where("database_id = ? AND status IN ?", databaseID, statuses).
		Find(&subs).Error; err != nil {
		return nil, err
	}
	return subs, nil
}
// FindLatestByDatabaseID returns the most recently created subscription for
// the database, or (nil, nil) when the database has none.
func (r *SubscriptionRepository) FindLatestByDatabaseID(databaseID uuid.UUID) (*billing_models.Subscription, error) {
	var sub billing_models.Subscription
	err := storage.GetDb().
		Where("database_id = ?", databaseID).
		Order("created_at DESC").
		First(&sub).Error
	if err == nil {
		return &sub, nil
	}
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, nil
	}
	return nil, err
}
// FindByProviderSubID looks a subscription up by the billing provider's
// subscription ID, returning (nil, nil) when no row matches.
func (r *SubscriptionRepository) FindByProviderSubID(providerSubID string) (*billing_models.Subscription, error) {
	var sub billing_models.Subscription
	err := storage.GetDb().Where("provider_sub_id = ?", providerSubID).
		First(&sub).Error
	switch {
	case err == nil:
		return &sub, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, nil
	default:
		return nil, err
	}
}
// FindByStatuses returns all subscriptions whose status is one of statuses.
func (r *SubscriptionRepository) FindByStatuses(
	statuses []billing_models.SubscriptionStatus,
) ([]billing_models.Subscription, error) {
	var subs []billing_models.Subscription
	err := storage.GetDb().Where("status IN ?", statuses).Find(&subs).Error
	if err != nil {
		return nil, err
	}
	return subs, nil
}
// FindCanceledWithEndedGracePeriod returns canceled subscriptions whose
// data-retention grace period expired before now.
func (r *SubscriptionRepository) FindCanceledWithEndedGracePeriod(
	now time.Time,
) ([]billing_models.Subscription, error) {
	var subs []billing_models.Subscription
	err := storage.GetDb().
		Where("status = ? AND data_retention_grace_period_until < ?", billing_models.StatusCanceled, now).
		Find(&subs).
		Error
	if err != nil {
		return nil, err
	}
	return subs, nil
}
// FindExpiredTrials returns trial subscriptions whose current billing
// period ended before now.
func (r *SubscriptionRepository) FindExpiredTrials(now time.Time) ([]billing_models.Subscription, error) {
	var subs []billing_models.Subscription
	err := storage.GetDb().
		Where("status = ? AND current_period_end < ?", billing_models.StatusTrial, now).
		Find(&subs).Error
	if err != nil {
		return nil, err
	}
	return subs, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,8 @@
package billing_webhooks
import "errors"
// Sentinel errors used by the webhook processing pipeline.
var (
	// ErrDuplicateWebhook indicates an event with the same provider event
	// ID was already received — presumably detected via the processed
	// webhook_records lookup; confirm against the processing code.
	ErrDuplicateWebhook = errors.New("duplicate webhook event")
	// ErrUnsupportedEventType indicates an event type this integration
	// deliberately does not handle; callers may skip such events.
	ErrUnsupportedEventType = errors.New("unsupported webhook event type")
)

View File

@@ -0,0 +1,25 @@
package billing_webhooks
import (
"time"
"github.com/google/uuid"
billing_provider "databasus-backend/internal/features/billing/provider"
)
// WebhookRecord is the persisted row for a single webhook delivery from a
// billing provider, used for idempotency tracking and debugging.
type WebhookRecord struct {
	// RequestID is the primary key assigned per incoming request.
	RequestID    uuid.UUID                     `gorm:"column:request_id;primaryKey;type:uuid;default:gen_random_uuid()"`
	ProviderName billing_provider.ProviderName `gorm:"column:provider_name;type:text;not null"`
	EventType    string                        `gorm:"column:event_type;type:text;not null"`
	// ProviderEventID is the provider's own event identifier; indexed so
	// duplicate deliveries can be looked up quickly.
	ProviderEventID string `gorm:"column:provider_event_id;type:text;not null;index"`
	RawPayload      string `gorm:"column:raw_payload;type:text;not null"`
	// ProcessedAt stays nil until the event has been handled (or skipped).
	ProcessedAt *time.Time `gorm:"column:processed_at"`
	// IsSkipped marks events that were deliberately not processed.
	IsSkipped bool    `gorm:"column:is_skipped;not null;default:false"`
	Error     *string `gorm:"column:error"`
	CreatedAt time.Time `gorm:"column:created_at;not null"`
}

// TableName tells GORM which table backs WebhookRecord.
func (WebhookRecord) TableName() string {
	return "webhook_records"
}

View File

@@ -0,0 +1,73 @@
package billing_webhooks
import (
"errors"
"time"
"gorm.io/gorm"
"databasus-backend/internal/storage"
)
// WebhookRepository persists raw webhook delivery records.
type WebhookRepository struct{}

// FindSuccessfulByProviderEventID returns the already-processed webhook
// record for the given provider event ID, or (nil, nil) when none exists.
//
// Fix: the original returned the error in BOTH branches of the
// errors.Is(err, gorm.ErrRecordNotFound) check, making the not-found branch
// dead code and "not found" indistinguishable from a real failure.
// Not-found now maps to (nil, nil), matching every other repository here.
func (r *WebhookRepository) FindSuccessfulByProviderEventID(providerEventID string) (*WebhookRecord, error) {
	var record WebhookRecord
	err := storage.GetDb().
		Where("provider_event_id = ? AND processed_at IS NOT NULL", providerEventID).
		First(&record).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}
	return &record, nil
}
// Insert stores a fresh webhook record, stamping CreatedAt with the current
// UTC time. The provider event ID is mandatory.
func (r *WebhookRepository) Insert(record *WebhookRecord) error {
	if record.ProviderEventID == "" {
		return errors.New("provider event ID is required")
	}
	record.CreatedAt = time.Now().UTC()
	db := storage.GetDb()
	return db.Create(record).Error
}
// MarkProcessed stamps the record identified by requestID as successfully
// processed at the current UTC time.
func (r *WebhookRepository) MarkProcessed(requestID string) error {
	return storage.GetDb().
		Model(&WebhookRecord{}).
		Where("request_id = ?", requestID).
		Update("processed_at", time.Now().UTC()).
		Error
}
// MarkSkipped flags the record as deliberately skipped while also stamping
// processed_at, so skipped events still count as handled.
func (r *WebhookRepository) MarkSkipped(requestID string) error {
	updates := map[string]any{
		"is_skipped":   true,
		"processed_at": time.Now().UTC(),
	}
	return storage.GetDb().
		Model(&WebhookRecord{}).
		Where("request_id = ?", requestID).
		Updates(updates).
		Error
}
// MarkError records errMsg on the webhook record identified by requestID.
// It does not touch processed_at, so failed events remain re-processable.
func (r *WebhookRepository) MarkError(requestID, errMsg string) error {
	db := storage.GetDb()
	return db.Model(&WebhookRecord{}).
		Where("request_id = ?", requestID).
		Update("error", errMsg).
		Error
}

View File

@@ -1328,6 +1328,143 @@ func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
}
}
// Test_CreateDatabase_WhenCloudAndUserIsNotReadOnly_ReturnsBadRequest checks
// that in cloud mode the create endpoint rejects a database whose configured
// user is not read-only, responding 400 with a validation message.
func Test_CreateDatabase_WhenCloudAndUserIsNotReadOnly_ReturnsBadRequest(t *testing.T) {
	enableCloud(t)
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Cloud Not ReadOnly", owner, router)
	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

	// The shared test postgres config's user is not read-only, so cloud
	// mode must refuse it (implied by the message asserted below).
	request := Database{
		Name:        "Cloud Non-ReadOnly DB",
		WorkspaceID: &workspace.ID,
		Type:        DatabaseTypePostgres,
		Postgresql:  getTestPostgresConfig(),
	}

	resp := test_utils.MakePostRequest(
		t,
		router,
		"/api/v1/databases/create",
		"Bearer "+owner.Token,
		request,
		http.StatusBadRequest,
	)
	assert.Contains(t, string(resp.Body), "in cloud mode, only read-only database users are allowed")
}
// Test_CreateDatabase_WhenCloudAndUserIsReadOnly_DatabaseCreated verifies the
// cloud-mode happy path: creating a database whose configured user is
// read-only succeeds with 201.
func Test_CreateDatabase_WhenCloudAndUserIsReadOnly_DatabaseCreated(t *testing.T) {
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Cloud ReadOnly", owner, router)
	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

	// Provision a read-only DB user while still in non-cloud mode, using a
	// throwaway database record that is removed right after.
	database := createTestDatabaseViaAPI("Temp DB for RO User", workspace.ID, owner.Token, router)
	readOnlyUser := createReadOnlyUserViaAPI(t, router, database.ID, owner.Token)
	assert.NotEmpty(t, readOnlyUser.Username)
	assert.NotEmpty(t, readOnlyUser.Password)
	RemoveTestDatabase(database)

	// Only now switch to cloud mode, so the creation below runs under
	// cloud validation rules.
	enableCloud(t)

	pgConfig := getTestPostgresConfig()
	pgConfig.Username = readOnlyUser.Username
	pgConfig.Password = readOnlyUser.Password
	request := Database{
		Name:        "Cloud ReadOnly DB",
		WorkspaceID: &workspace.ID,
		Type:        DatabaseTypePostgres,
		Postgresql:  pgConfig,
	}

	var response Database
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/databases/create",
		"Bearer "+owner.Token,
		request,
		http.StatusCreated,
		&response,
	)
	defer RemoveTestDatabase(&response)

	assert.Equal(t, "Cloud ReadOnly DB", response.Name)
	assert.NotEqual(t, uuid.Nil, response.ID)
}
// Test_CreateDatabase_WhenNotCloudAndUserIsNotReadOnly_DatabaseCreated checks
// that outside cloud mode the read-only-user restriction does not apply and
// the database is created normally.
func Test_CreateDatabase_WhenNotCloudAndUserIsNotReadOnly_DatabaseCreated(t *testing.T) {
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Non-Cloud", owner, router)
	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

	request := Database{
		Name:        "Non-Cloud DB",
		WorkspaceID: &workspace.ID,
		Type:        DatabaseTypePostgres,
		Postgresql:  getTestPostgresConfig(),
	}

	var response Database
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/databases/create",
		"Bearer "+owner.Token,
		request,
		http.StatusCreated,
		&response,
	)
	defer RemoveTestDatabase(&response)

	assert.Equal(t, "Non-Cloud DB", response.Name)
	assert.NotEqual(t, uuid.Nil, response.ID)
}
// enableCloud switches the process-wide config into cloud mode for the
// duration of the test and restores it on cleanup. It mutates shared global
// state, so tests using it must not run in parallel with tests reading
// IsCloud.
func enableCloud(t *testing.T) {
	t.Helper()
	config.GetEnv().IsCloud = true
	t.Cleanup(func() {
		config.GetEnv().IsCloud = false
	})
}
// createReadOnlyUserViaAPI fetches the database record (to obtain its full
// connection config) and asks the API to provision a read-only database
// user for it, returning the generated credentials.
func createReadOnlyUserViaAPI(
	t *testing.T,
	router *gin.Engine,
	databaseID uuid.UUID,
	token string,
) *CreateReadOnlyUserResponse {
	// The create-readonly-user endpoint expects the full database payload,
	// so fetch it first.
	var database Database
	test_utils.MakeGetRequestAndUnmarshal(
		t,
		router,
		fmt.Sprintf("/api/v1/databases/%s", databaseID.String()),
		"Bearer "+token,
		http.StatusOK,
		&database,
	)

	var response CreateReadOnlyUserResponse
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/databases/create-readonly-user",
		"Bearer "+token,
		database,
		http.StatusOK,
		&response,
	)
	return &response
}
func getTestMariadbConfig() *mariadb.MariadbDatabase {
env := config.GetEnv()
portStr := env.TestMariadb1011Port

View File

@@ -1,7 +1,6 @@
package mariadb
import (
"context"
"fmt"
"log/slog"
"os"
@@ -212,7 +211,7 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
mariadbModel := createMariadbModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
isReadOnly, privileges, err := mariadbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -241,7 +240,7 @@ func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
mariadbModel := createMariadbModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -313,7 +312,7 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
mariadbModel := createMariadbModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -390,7 +389,7 @@ func Test_ReadOnlyUser_FutureTables_NoSelectPermission(t *testing.T) {
mariadbModel := createMariadbModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -466,7 +465,7 @@ func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -511,7 +510,7 @@ func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
mariadbModel := createMariadbModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)

View File

@@ -42,9 +42,9 @@ func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
t.Parallel()
container := connectToMongodbContainer(t, tc.port, tc.version)
defer container.Client.Disconnect(context.Background())
defer container.Client.Disconnect(t.Context())
ctx := context.Background()
ctx := t.Context()
db := container.Client.Database(container.Database)
_ = db.Collection("permission_test").Drop(ctx)
@@ -108,9 +108,9 @@ func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
t.Parallel()
container := connectToMongodbContainer(t, tc.port, tc.version)
defer container.Client.Disconnect(context.Background())
defer container.Client.Disconnect(t.Context())
ctx := context.Background()
ctx := t.Context()
db := container.Client.Database(container.Database)
_ = db.Collection("backup_test").Drop(ctx)
@@ -178,11 +178,11 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
t.Parallel()
container := connectToMongodbContainer(t, tc.port, tc.version)
defer container.Client.Disconnect(context.Background())
defer container.Client.Disconnect(t.Context())
mongodbModel := createMongodbModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
isReadOnly, roles, err := mongodbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -195,9 +195,9 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
env := config.GetEnv()
container := connectToMongodbContainer(t, env.TestMongodb70Port, tools.MongodbVersion7)
defer container.Client.Disconnect(context.Background())
defer container.Client.Disconnect(t.Context())
ctx := context.Background()
ctx := t.Context()
db := container.Client.Database(container.Database)
_ = db.Collection("readonly_check_test").Drop(ctx)
@@ -251,15 +251,15 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
t.Parallel()
container := connectToMongodbContainer(t, tc.port, tc.version)
defer container.Client.Disconnect(context.Background())
defer container.Client.Disconnect(t.Context())
ctx := context.Background()
ctx := t.Context()
db := container.Client.Database(container.Database)
_ = db.Collection("readonly_test").Drop(ctx)
_ = db.Collection("hack_collection").Drop(ctx)
_, err := db.Collection("readonly_test").InsertMany(ctx, []interface{}{
_, err := db.Collection("readonly_test").InsertMany(ctx, []any{
bson.M{"data": "test1"},
bson.M{"data": "test2"},
})
@@ -317,9 +317,9 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
func Test_ReadOnlyUser_FutureCollections_CanSelect(t *testing.T) {
env := config.GetEnv()
container := connectToMongodbContainer(t, env.TestMongodb70Port, tools.MongodbVersion7)
defer container.Client.Disconnect(context.Background())
defer container.Client.Disconnect(t.Context())
ctx := context.Background()
ctx := t.Context()
db := container.Client.Database(container.Database)
mongodbModel := createMongodbModel(container)
@@ -348,9 +348,9 @@ func Test_ReadOnlyUser_FutureCollections_CanSelect(t *testing.T) {
func Test_ReadOnlyUser_CannotDropOrModifyCollections(t *testing.T) {
env := config.GetEnv()
container := connectToMongodbContainer(t, env.TestMongodb70Port, tools.MongodbVersion7)
defer container.Client.Disconnect(context.Background())
defer container.Client.Disconnect(t.Context())
ctx := context.Background()
ctx := t.Context()
db := container.Client.Database(container.Database)
_ = db.Collection("drop_test").Drop(ctx)
@@ -420,7 +420,7 @@ func connectToMongodbContainer(
authDatabase,
)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
defer cancel()
clientOptions := options.Client().ApplyURI(uri)
@@ -473,7 +473,7 @@ func connectWithCredentials(
container.Database, container.AuthDatabase,
)
ctx := context.Background()
ctx := t.Context()
clientOptions := options.Client().ApplyURI(uri)
client, err := mongo.Connect(ctx, clientOptions)
assert.NoError(t, err)

View File

@@ -1,7 +1,6 @@
package mysql
import (
"context"
"fmt"
"log/slog"
"os"
@@ -231,7 +230,7 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
mysqlModel := createMysqlModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
isReadOnly, privileges, err := mysqlModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -260,7 +259,7 @@ func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
mysqlModel := createMysqlModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -326,7 +325,7 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
mysqlModel := createMysqlModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -400,7 +399,7 @@ func Test_ReadOnlyUser_FutureTables_NoSelectPermission(t *testing.T) {
mysqlModel := createMysqlModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -477,7 +476,7 @@ func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -523,7 +522,7 @@ func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
mysqlModel := createMysqlModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)

View File

@@ -81,8 +81,8 @@ func (p *PostgresqlDatabase) Validate() error {
p.BackupType = PostgresBackupTypePgDump
}
if p.BackupType == PostgresBackupTypePgDump && config.GetEnv().IsCloud {
return errors.New("PG_DUMP backup type is not supported in cloud mode")
if p.BackupType != PostgresBackupTypePgDump && config.GetEnv().IsCloud {
return errors.New("only PG_DUMP backup type is supported in cloud mode")
}
if p.BackupType == PostgresBackupTypePgDump {

View File

@@ -1,7 +1,6 @@
package postgresql
import (
"context"
"fmt"
"log/slog"
"os"
@@ -267,7 +266,7 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
isReadOnly, privileges, err := pgModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -294,7 +293,7 @@ func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -359,7 +358,7 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -438,7 +437,7 @@ func Test_ReadOnlyUser_FutureTables_HaveSelectPermission(t *testing.T) {
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -491,7 +490,7 @@ func Test_ReadOnlyUser_MultipleSchemas_AllAccessible(t *testing.T) {
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -566,7 +565,7 @@ func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -653,7 +652,7 @@ func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
connectionUsername, newPassword, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -743,7 +742,7 @@ func Test_CreateReadOnlyUser_WithPublicSchema_Success(t *testing.T) {
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
@@ -851,7 +850,7 @@ func Test_CreateReadOnlyUser_WithoutPublicSchema_Success(t *testing.T) {
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err, "CreateReadOnlyUser should succeed without public schema")
@@ -1018,7 +1017,7 @@ func Test_CreateReadOnlyUser_PublicSchemaExistsButNoPermissions_ReturnsError(t *
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.Error(
@@ -1310,6 +1309,46 @@ func Test_Validate_WhenRequiredFieldsMissing_ReturnsError(t *testing.T) {
}
}
// Test_Validate_WhenCloudAndBackupTypeIsNotPgDump_ValidationFails checks
// that cloud mode only accepts the PG_DUMP backup type: a WAL backup type
// must be rejected with a descriptive error.
func Test_Validate_WhenCloudAndBackupTypeIsNotPgDump_ValidationFails(t *testing.T) {
	enableCloud(t)
	model := &PostgresqlDatabase{
		Host:       "example.com",
		Port:       5432,
		Username:   "user",
		Password:   "pass",
		CpuCount:   1,
		BackupType: PostgresBackupTypeWalV1,
	}
	err := model.Validate()
	assert.EqualError(t, err, "only PG_DUMP backup type is supported in cloud mode")
}
// Test_Validate_WhenCloudAndBackupTypeIsPgDump_ValidationPasses verifies that
// Validate accepts the PG_DUMP backup type while cloud mode is enabled.
func Test_Validate_WhenCloudAndBackupTypeIsPgDump_ValidationPasses(t *testing.T) {
	enableCloud(t)

	db := &PostgresqlDatabase{
		Host:       "example.com",
		Port:       5432,
		Username:   "user",
		Password:   "pass",
		CpuCount:   1,
		BackupType: PostgresBackupTypePgDump,
	}

	assert.NoError(t, db.Validate())
}
// enableCloud switches the process-wide config into cloud mode for the
// duration of the test and restores the previous value on cleanup.
//
// Restoring the saved value (instead of hard-coding false) keeps the helper
// correct if cloud mode was already enabled before the test ran, or if
// helpers are ever nested. Note this mutates shared global state, so tests
// using it must not run with t.Parallel().
func enableCloud(t *testing.T) {
	t.Helper()
	prev := config.GetEnv().IsCloud
	config.GetEnv().IsCloud = true
	t.Cleanup(func() {
		config.GetEnv().IsCloud = prev
	})
}
type PostgresContainer struct {
Host string
Port int
@@ -1395,7 +1434,7 @@ func Test_CreateReadOnlyUser_TablesCreatedByDifferentUser_ReadOnlyUserCanRead(t
// At this point, user_creator already owns objects, so ALTER DEFAULT PRIVILEGES FOR ROLE should apply
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
ctx,
@@ -1562,7 +1601,7 @@ func Test_CreateReadOnlyUser_WithIncludeSchemas_OnlyGrantsAccessToSpecifiedSchem
pgModel.IncludeSchemas = []string{"public", "included_schema"}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
ctx := t.Context()
readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
ctx,

View File

@@ -2,7 +2,6 @@ package databases
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/notifiers"
@@ -40,22 +39,7 @@ func GetDatabaseController() *DatabaseController {
return databaseController
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}
// SetupDependencies wires this package's cross-feature listeners exactly once.
// sync.OnceFunc guarantees subsequent calls are no-ops, replacing the previous
// hand-rolled Once + atomic.Bool guard.
var SetupDependencies = sync.OnceFunc(func() {
	workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
	notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
})

View File

@@ -25,16 +25,16 @@ type Database struct {
Name string `json:"name" gorm:"column:name;type:text;not null"`
Type DatabaseType `json:"type" gorm:"column:type;type:text;not null"`
Postgresql *postgresql.PostgresqlDatabase `json:"postgresql,omitempty" gorm:"foreignKey:DatabaseID"`
Mysql *mysql.MysqlDatabase `json:"mysql,omitempty" gorm:"foreignKey:DatabaseID"`
Mariadb *mariadb.MariadbDatabase `json:"mariadb,omitempty" gorm:"foreignKey:DatabaseID"`
Mongodb *mongodb.MongodbDatabase `json:"mongodb,omitempty" gorm:"foreignKey:DatabaseID"`
Postgresql *postgresql.PostgresqlDatabase `json:"postgresql,omitzero" gorm:"foreignKey:DatabaseID"`
Mysql *mysql.MysqlDatabase `json:"mysql,omitzero" gorm:"foreignKey:DatabaseID"`
Mariadb *mariadb.MariadbDatabase `json:"mariadb,omitzero" gorm:"foreignKey:DatabaseID"`
Mongodb *mongodb.MongodbDatabase `json:"mongodb,omitzero" gorm:"foreignKey:DatabaseID"`
Notifiers []notifiers.Notifier `json:"notifiers" gorm:"many2many:database_notifiers;"`
// these fields are not reliable, but
// they are used for pretty UI
LastBackupTime *time.Time `json:"lastBackupTime,omitempty" gorm:"column:last_backup_time;type:timestamp with time zone"`
LastBackupTime *time.Time `json:"lastBackupTime,omitzero" gorm:"column:last_backup_time;type:timestamp with time zone"`
LastBackupErrorMessage *string `json:"lastBackupErrorMessage,omitempty" gorm:"column:last_backup_error_message;type:text"`
HealthStatus *HealthStatus `json:"healthStatus" gorm:"column:health_status;type:text;not null"`

View File

@@ -1 +0,0 @@
package secrets

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
@@ -16,34 +15,28 @@ type HealthcheckAttemptBackgroundService struct {
checkDatabaseHealthUseCase *CheckDatabaseHealthUseCase
logger *slog.Logger
runOnce sync.Once
hasRun atomic.Bool
hasRun atomic.Bool
}
func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) {
wasAlreadyRun := s.hasRun.Load()
s.runOnce.Do(func() {
s.hasRun.Store(true)
// first healthcheck immediately
s.checkDatabases()
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.checkDatabases()
}
}
})
if wasAlreadyRun {
if s.hasRun.Swap(true) {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
// first healthcheck immediately
s.checkDatabases()
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.checkDatabases()
}
}
}
func (s *HealthcheckAttemptBackgroundService) checkDatabases() {

View File

@@ -1,9 +1,6 @@
package healthcheck_attempt
import (
"sync"
"sync/atomic"
"databasus-backend/internal/features/databases"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"databasus-backend/internal/features/notifiers"
@@ -30,8 +27,6 @@ var healthcheckAttemptBackgroundService = &HealthcheckAttemptBackgroundService{
healthcheckConfigService: healthcheck_config.GetHealthcheckConfigService(),
checkDatabaseHealthUseCase: checkDatabaseHealthUseCase,
logger: logger.GetLogger(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
var healthcheckAttemptController = &HealthcheckAttemptController{

View File

@@ -2,7 +2,6 @@ package healthcheck_config
import (
"sync"
"sync/atomic"
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/databases"
@@ -33,23 +32,8 @@ func GetHealthcheckConfigController() *HealthcheckConfigController {
return healthcheckConfigController
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
databases.
GetDatabaseService().
AddDbCreationListener(healthcheckConfigService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}
// SetupDependencies registers the healthcheck-config service as a database
// creation listener, exactly once. sync.OnceFunc makes repeat calls no-ops.
var SetupDependencies = sync.OnceFunc(func() {
	databases.
		GetDatabaseService().
		AddDbCreationListener(healthcheckConfigService)
})

View File

@@ -2,7 +2,6 @@ package notifiers
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
workspaces_services "databasus-backend/internal/features/workspaces/services"
@@ -39,21 +38,6 @@ func GetNotifierRepository() *NotifierRepository {
return notifierRepository
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}
// SetupDependencies registers the notifier service as a workspace deletion
// listener, exactly once. sync.OnceFunc makes repeat calls no-ops.
var SetupDependencies = sync.OnceFunc(func() {
	workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
})

View File

@@ -23,12 +23,12 @@ type Notifier struct {
LastSendError *string `json:"lastSendError" gorm:"column:last_send_error;type:text"`
// specific notifier
TelegramNotifier *telegram_notifier.TelegramNotifier `json:"telegramNotifier" gorm:"foreignKey:NotifierID"`
EmailNotifier *email_notifier.EmailNotifier `json:"emailNotifier" gorm:"foreignKey:NotifierID"`
WebhookNotifier *webhook_notifier.WebhookNotifier `json:"webhookNotifier" gorm:"foreignKey:NotifierID"`
SlackNotifier *slack_notifier.SlackNotifier `json:"slackNotifier" gorm:"foreignKey:NotifierID"`
DiscordNotifier *discord_notifier.DiscordNotifier `json:"discordNotifier" gorm:"foreignKey:NotifierID"`
TeamsNotifier *teams_notifier.TeamsNotifier `json:"teamsNotifier,omitempty" gorm:"foreignKey:NotifierID;constraint:OnDelete:CASCADE"`
TelegramNotifier *telegram_notifier.TelegramNotifier `json:"telegramNotifier" gorm:"foreignKey:NotifierID"`
EmailNotifier *email_notifier.EmailNotifier `json:"emailNotifier" gorm:"foreignKey:NotifierID"`
WebhookNotifier *webhook_notifier.WebhookNotifier `json:"webhookNotifier" gorm:"foreignKey:NotifierID"`
SlackNotifier *slack_notifier.SlackNotifier `json:"slackNotifier" gorm:"foreignKey:NotifierID"`
DiscordNotifier *discord_notifier.DiscordNotifier `json:"discordNotifier" gorm:"foreignKey:NotifierID"`
TeamsNotifier *teams_notifier.TeamsNotifier `json:"teamsNotifier,omitzero" gorm:"foreignKey:NotifierID;constraint:OnDelete:CASCADE"`
}
func (n *Notifier) TableName() string {

View File

@@ -49,7 +49,7 @@ type cardAttachment struct {
// payload is the JSON body posted to the Teams webhook.
//
// NOTE(review): the diff view showed both the old (omitempty) and new
// (omitzero) Attachments lines; keeping both would be a duplicate-field
// compile error, so only the new one is retained. omitzero (encoding/json,
// Go 1.24+) omits the field when the slice is its zero value (nil).
type payload struct {
	Title       string           `json:"title"`
	Text        string           `json:"text"`
	Attachments []cardAttachment `json:"attachments,omitzero"`
}
func (n *TeamsNotifier) Send(

Some files were not shown because too many files have changed in this diff Show More