mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 08:41:58 +02:00
Compare commits
34 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8c8c712d97 | ||
|
|
4e1cee2aa2 | ||
|
|
18b8178608 | ||
|
|
02d9cda86f | ||
|
|
cefedb6ddd | ||
|
|
27d891fb34 | ||
|
|
d1c41ed53a | ||
|
|
f287967b5d | ||
|
|
44ddcb836e | ||
|
|
7913c1b474 | ||
|
|
189573fa1b | ||
|
|
63e23b2489 | ||
|
|
1926096377 | ||
|
|
0a131511a8 | ||
|
|
aa01ce0b76 | ||
|
|
1ac0eb4d5b | ||
|
|
c7d091fe51 | ||
|
|
b1dfd1c425 | ||
|
|
4bee78646a | ||
|
|
3a5a53c92d | ||
|
|
f0ab470a84 | ||
|
|
f728fda759 | ||
|
|
80b5df6283 | ||
|
|
67556a0db1 | ||
|
|
c4cf7f8446 | ||
|
|
61a0bcabb1 | ||
|
|
f1e289c421 | ||
|
|
c0952e057f | ||
|
|
b4d4e0a1d7 | ||
|
|
c648e9c29f | ||
|
|
3fce6d2a99 | ||
|
|
198b94ba9d | ||
|
|
80cd0bf5d3 | ||
|
|
231e3cc709 |
42
Dockerfile
42
Dockerfile
@@ -239,7 +239,8 @@ RUN apt-get update && \
|
||||
fi
|
||||
|
||||
# Create postgres user and set up directories
|
||||
RUN useradd -m -s /bin/bash postgres || true && \
|
||||
RUN groupadd -g 999 postgres || true && \
|
||||
useradd -m -s /bin/bash -u 999 -g 999 postgres || true && \
|
||||
mkdir -p /databasus-data/pgdata && \
|
||||
chown -R postgres:postgres /databasus-data/pgdata
|
||||
|
||||
@@ -294,6 +295,23 @@ if [ -d "/postgresus-data" ] && [ "\$(ls -A /postgresus-data 2>/dev/null)" ]; th
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# ========= Adjust postgres user UID/GID =========
|
||||
PUID=\${PUID:-999}
|
||||
PGID=\${PGID:-999}
|
||||
|
||||
CURRENT_UID=\$(id -u postgres)
|
||||
CURRENT_GID=\$(id -g postgres)
|
||||
|
||||
if [ "\$CURRENT_GID" != "\$PGID" ]; then
|
||||
echo "Adjusting postgres group GID from \$CURRENT_GID to \$PGID..."
|
||||
groupmod -o -g "\$PGID" postgres
|
||||
fi
|
||||
|
||||
if [ "\$CURRENT_UID" != "\$PUID" ]; then
|
||||
echo "Adjusting postgres user UID from \$CURRENT_UID to \$PUID..."
|
||||
usermod -o -u "\$PUID" postgres
|
||||
fi
|
||||
|
||||
# PostgreSQL 17 binary paths
|
||||
PG_BIN="/usr/lib/postgresql/17/bin"
|
||||
|
||||
@@ -316,7 +334,9 @@ window.__RUNTIME_CONFIG__ = {
|
||||
GOOGLE_CLIENT_ID: '\${GOOGLE_CLIENT_ID:-}',
|
||||
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED',
|
||||
CLOUDFLARE_TURNSTILE_SITE_KEY: '\${CLOUDFLARE_TURNSTILE_SITE_KEY:-}',
|
||||
CONTAINER_ARCH: '\${CONTAINER_ARCH:-unknown}'
|
||||
CONTAINER_ARCH: '\${CONTAINER_ARCH:-unknown}',
|
||||
CLOUD_PRICE_PER_GB: '\${CLOUD_PRICE_PER_GB:-}',
|
||||
CLOUD_PADDLE_CLIENT_TOKEN: '\${CLOUD_PADDLE_CLIENT_TOKEN:-}'
|
||||
};
|
||||
JSEOF
|
||||
|
||||
@@ -329,6 +349,15 @@ if [ -n "\${ANALYTICS_SCRIPT:-}" ]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
# Inject Paddle script if client token is provided (only if not already injected)
|
||||
if [ -n "\${CLOUD_PADDLE_CLIENT_TOKEN:-}" ]; then
|
||||
if ! grep -q "cdn.paddle.com" /app/ui/build/index.html 2>/dev/null; then
|
||||
echo "Injecting Paddle script..."
|
||||
sed -i "s#</head># <script src=\"https://cdn.paddle.com/paddle/v2/paddle.js\"></script>\\
|
||||
</head>#" /app/ui/build/index.html
|
||||
fi
|
||||
fi
|
||||
|
||||
# Inject static HTML into root div for cloud mode (payment system requires visible legal links)
|
||||
if [ "\${IS_CLOUD:-false}" = "true" ]; then
|
||||
if ! grep -q "cloud-static-content" /app/ui/build/index.html 2>/dev/null; then
|
||||
@@ -341,7 +370,7 @@ if [ "\${IS_CLOUD:-false}" = "true" ]; then
|
||||
close \$fh;
|
||||
\$c =~ s/\\n/ /g;
|
||||
}
|
||||
s/<div id="root"><\\/div>/<div id="root"><!-- cloud-static-content -->\$c<\\/div>/
|
||||
s/<div id="root"><\\/div>/<div id="root"><!-- cloud-static-content --><noscript>\$c<\\/noscript><\\/div>/
|
||||
' /app/ui/build/index.html
|
||||
fi
|
||||
fi
|
||||
@@ -395,7 +424,12 @@ fi
|
||||
# Function to start PostgreSQL and wait for it to be ready
|
||||
start_postgres() {
|
||||
echo "Starting PostgreSQL..."
|
||||
gosu postgres \$PG_BIN/postgres -D /databasus-data/pgdata -p 5437 &
|
||||
# -k /tmp: create Unix socket and lock file in /tmp instead of /var/run/postgresql/.
|
||||
# On NAS systems (e.g. TrueNAS Scale), the ZFS-backed Docker overlay filesystem
|
||||
# ignores chown/chmod on directories from image layers, so PostgreSQL gets
|
||||
# "Permission denied" when creating .s.PGSQL.5437.lock in /var/run/postgresql/.
|
||||
# All internal connections use TCP (-h localhost), so the socket location does not matter.
|
||||
gosu postgres \$PG_BIN/postgres -D /databasus-data/pgdata -p 5437 -k /tmp &
|
||||
POSTGRES_PID=\$!
|
||||
|
||||
echo "Waiting for PostgreSQL to be ready..."
|
||||
|
||||
22
NOTICE.md
Normal file
22
NOTICE.md
Normal file
@@ -0,0 +1,22 @@
|
||||
Copyright © 2025–2026 Rostislav Dugin and contributors.
|
||||
|
||||
“Databasus” is a trademark of Rostislav Dugin.
|
||||
|
||||
The source code in this repository is licensed under the Apache License, Version 2.0.
|
||||
That license applies to the code only and does not grant any right to use the
|
||||
Databasus name, logo, or branding, except for reasonable and customary referential
|
||||
use in describing the origin of the software and reproducing the content of this NOTICE.
|
||||
|
||||
Permitted referential use includes truthful use of the name “Databasus” to identify
|
||||
the original Databasus project in software catalogs, deployment templates, hosting
|
||||
panels, package indexes, compatibility pages, integrations, tutorials, reviews, and
|
||||
similar informational materials, including phrases such as “Databasus”,
|
||||
“Deploy Databasus”, “Databasus on Coolify”, and “Compatible with Databasus”.
|
||||
|
||||
You may not use “Databasus” as the name or primary branding of a competing product,
|
||||
service, fork, distribution, or hosted offering, or in any manner likely to cause
|
||||
confusion as to source, affiliation, sponsorship, or endorsement.
|
||||
|
||||
Nothing in this repository transfers, waives, limits, or estops any rights in the
|
||||
Databasus mark. All trademark rights are reserved except for the limited referential
|
||||
use stated above.
|
||||
22
README.md
22
README.md
@@ -1,8 +1,8 @@
|
||||
<div align="center">
|
||||
<img src="assets/logo.svg" alt="Databasus Logo" width="250"/>
|
||||
|
||||
<h3>Backup tool for PostgreSQL, MySQL and MongoDB</h3>
|
||||
<p>Databasus is a free, open source and self-hosted tool to backup databases (with focus on PostgreSQL). Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.)</p>
|
||||
<h3>PostgreSQL backup tool (with MySQL\MariaDB and MongoDB support)</h3>
|
||||
<p>Databasus is a free, open source and self-hosted tool to backup databases (with primary focus on PostgreSQL). Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.)</p>
|
||||
|
||||
<!-- Badges -->
|
||||
[](https://www.postgresql.org/)
|
||||
@@ -99,8 +99,8 @@ It is also important for Databasus that you are able to decrypt and restore back
|
||||
### 📦 **Backup types**
|
||||
|
||||
- **Logical** — Native dump of the database in its engine-specific binary format. Compressed and streamed directly to storage with no intermediate files
|
||||
- **Physical** — File-level copy of the entire database cluster. Faster backup and restore for large datasets compared to logical dumps (requires agent)
|
||||
- **Incremental** — Physical base backup combined with continuous WAL segment archiving. Enables Point-in-Time Recovery (PITR) — restore to any second between backups. Designed for disaster recovery and near-zero data loss requirements (requires agent)
|
||||
- **Physical** — File-level copy of the entire database cluster. Faster backup and restore for large datasets compared to logical dumps
|
||||
- **Incremental** — Physical base backup combined with continuous WAL segment archiving. Enables Point-in-time recovery (PITR) — restore to any second between backups. Designed for disaster recovery and near-zero data loss requirements
|
||||
|
||||
### 🐳 **Self-hosted & secure**
|
||||
|
||||
@@ -259,7 +259,9 @@ Contributions are welcome! Read the <a href="https://databasus.com/contribute">c
|
||||
|
||||
Also you can join our large community of developers, DBAs and DevOps engineers on Telegram [@databasus_community](https://t.me/databasus_community).
|
||||
|
||||
## AI disclaimer
|
||||
## FAQ
|
||||
|
||||
### AI disclaimer
|
||||
|
||||
There have been questions about AI usage in project development in issues and discussions. As the project focuses on security, reliability and production usage, it's important to explain how AI is used in the development process.
|
||||
|
||||
@@ -295,3 +297,13 @@ Moreover, it's important to note that we do not differentiate between bad human
|
||||
Even if code is written manually by a human, it's not guaranteed to be merged. Vibe code is not allowed at all and all such PRs are rejected by default (see [contributing guide](https://databasus.com/contribute)).
|
||||
|
||||
We also draw attention to fast issue resolution and security [vulnerability reporting](https://github.com/databasus/databasus?tab=security-ov-file#readme).
|
||||
|
||||
### You have a cloud version — are you truly open source?
|
||||
|
||||
Yes. Every feature available in Databasus Cloud is equally available in the self-hosted version with no restrictions, no feature gates and no usage limits. The entire codebase is Apache 2.0 licensed and always will be.
|
||||
|
||||
Databasus is not "open core." We do not withhold features behind a paid tier and then call the limited remainder "open source," as projects like GitLab or Sentry do. We believe open source means the complete product is open, not just a marketing label on a stripped-down edition.
|
||||
|
||||
Databasus Cloud runs the exact same code as the self-hosted version. The only difference is that we take care of infrastructure, availability, backups, reservations, monitoring and updates for you — so you don't have to. If you are using cloud, you can always move your databases from cloud to self-hosted if you wish.
|
||||
|
||||
Revenue from Cloud funds full-time development of the project. Most large open-source projects rely on corporate backing or sponsorship to survive. We chose a different path: Databasus sustains itself so it can grow and improve independently, without being tied to any enterprise or sponsor.
|
||||
|
||||
@@ -110,8 +110,7 @@ func (c *Config) applyDefaults() {
|
||||
}
|
||||
|
||||
if c.IsDeleteWalAfterUpload == nil {
|
||||
v := true
|
||||
c.IsDeleteWalAfterUpload = &v
|
||||
c.IsDeleteWalAfterUpload = new(true)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
60
agent/internal/features/api/idle_timeout_reader.go
Normal file
60
agent/internal/features/api/idle_timeout_reader.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IdleTimeoutReader wraps an io.Reader and cancels the associated context
|
||||
// if no bytes are successfully read within the specified timeout duration.
|
||||
// This detects stalled uploads where the network or source stops transmitting data.
|
||||
//
|
||||
// When the idle timeout fires, the reader is also closed (if it implements io.Closer)
|
||||
// to unblock any goroutine blocked on the underlying Read.
|
||||
type IdleTimeoutReader struct {
|
||||
reader io.Reader
|
||||
timeout time.Duration
|
||||
cancel context.CancelCauseFunc
|
||||
timer *time.Timer
|
||||
}
|
||||
|
||||
// NewIdleTimeoutReader creates a reader that cancels the context via cancel
|
||||
// if Read does not return any bytes for the given timeout duration.
|
||||
func NewIdleTimeoutReader(reader io.Reader, timeout time.Duration, cancel context.CancelCauseFunc) *IdleTimeoutReader {
|
||||
r := &IdleTimeoutReader{
|
||||
reader: reader,
|
||||
timeout: timeout,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
r.timer = time.AfterFunc(timeout, func() {
|
||||
cancel(fmt.Errorf("upload idle timeout: no bytes transmitted for %v", timeout))
|
||||
|
||||
if closer, ok := reader.(io.Closer); ok {
|
||||
_ = closer.Close()
|
||||
}
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *IdleTimeoutReader) Read(p []byte) (int, error) {
|
||||
n, err := r.reader.Read(p)
|
||||
|
||||
if n > 0 {
|
||||
r.timer.Reset(r.timeout)
|
||||
}
|
||||
|
||||
if err != nil && err != io.EOF {
|
||||
r.Stop()
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Stop cancels the idle timer. Must be called when the reader is no longer needed.
|
||||
func (r *IdleTimeoutReader) Stop() {
|
||||
r.timer.Stop()
|
||||
}
|
||||
112
agent/internal/features/api/idle_timeout_reader_test.go
Normal file
112
agent/internal/features/api/idle_timeout_reader_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_ReadThroughIdleTimeoutReader_WhenBytesFlowContinuously_DoesNotCancelContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
defer cancel(nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
idleReader := NewIdleTimeoutReader(pr, 200*time.Millisecond, cancel)
|
||||
defer idleReader.Stop()
|
||||
|
||||
go func() {
|
||||
for range 5 {
|
||||
_, _ = pw.Write([]byte("data"))
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
_ = pw.Close()
|
||||
}()
|
||||
|
||||
data, err := io.ReadAll(idleReader)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "datadatadatadatadata", string(data))
|
||||
assert.NoError(t, ctx.Err(), "context should not be cancelled when bytes flow continuously")
|
||||
}
|
||||
|
||||
func Test_ReadThroughIdleTimeoutReader_WhenNoBytesTransmitted_CancelsContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
defer cancel(nil)
|
||||
|
||||
pr, _ := io.Pipe()
|
||||
|
||||
idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel)
|
||||
defer idleReader.Stop()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
assert.Error(t, ctx.Err(), "context should be cancelled when no bytes are transmitted")
|
||||
assert.Contains(t, context.Cause(ctx).Error(), "upload idle timeout")
|
||||
}
|
||||
|
||||
func Test_ReadThroughIdleTimeoutReader_WhenBytesStopMidStream_CancelsContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
defer cancel(nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel)
|
||||
defer idleReader.Stop()
|
||||
|
||||
go func() {
|
||||
_, _ = pw.Write([]byte("initial"))
|
||||
// Stop writing — simulate stalled source
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
n, _ := idleReader.Read(buf)
|
||||
assert.Equal(t, "initial", string(buf[:n]))
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
assert.Error(t, ctx.Err(), "context should be cancelled when bytes stop mid-stream")
|
||||
assert.Contains(t, context.Cause(ctx).Error(), "upload idle timeout")
|
||||
}
|
||||
|
||||
func Test_StopIdleTimeoutReader_WhenCalledBeforeTimeout_DoesNotCancelContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
defer cancel(nil)
|
||||
|
||||
pr, _ := io.Pipe()
|
||||
|
||||
idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel)
|
||||
idleReader.Stop()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
assert.NoError(t, ctx.Err(), "context should not be cancelled when reader is stopped before timeout")
|
||||
}
|
||||
|
||||
func Test_ReadThroughIdleTimeoutReader_WhenReaderReturnsError_PropagatesError(t *testing.T) {
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
defer cancel(nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
idleReader := NewIdleTimeoutReader(pr, 5*time.Second, cancel)
|
||||
defer idleReader.Stop()
|
||||
|
||||
expectedErr := fmt.Errorf("test read error")
|
||||
_ = pw.CloseWithError(expectedErr)
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
_, err := idleReader.Read(buf)
|
||||
|
||||
assert.ErrorIs(t, err, expectedErr)
|
||||
|
||||
// Timer should be stopped after error — context should not be cancelled
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
assert.NoError(t, ctx.Err(), "context should not be cancelled after reader error stops the timer")
|
||||
}
|
||||
@@ -21,9 +21,11 @@ import (
|
||||
const (
|
||||
checkInterval = 30 * time.Second
|
||||
retryDelay = 1 * time.Minute
|
||||
uploadTimeout = 30 * time.Minute
|
||||
uploadTimeout = 23 * time.Hour
|
||||
)
|
||||
|
||||
var uploadIdleTimeout = 5 * time.Minute
|
||||
|
||||
var retryDelayOverride *time.Duration
|
||||
|
||||
type CmdBuilder func(ctx context.Context) *exec.Cmd
|
||||
@@ -176,16 +178,32 @@ func (backuper *FullBackuper) executeAndUploadBasebackup(ctx context.Context) er
|
||||
|
||||
// Phase 1: Stream compressed data via io.Pipe directly to the API.
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
defer func() { _ = pipeReader.Close() }()
|
||||
|
||||
go backuper.compressAndStream(pipeWriter, stdoutPipe)
|
||||
|
||||
uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout)
|
||||
defer cancel()
|
||||
uploadCtx, timeoutCancel := context.WithTimeout(ctx, uploadTimeout)
|
||||
defer timeoutCancel()
|
||||
|
||||
uploadResp, uploadErr := backuper.apiClient.UploadBasebackup(uploadCtx, pipeReader)
|
||||
idleCtx, idleCancel := context.WithCancelCause(uploadCtx)
|
||||
defer idleCancel(nil)
|
||||
|
||||
idleReader := api.NewIdleTimeoutReader(pipeReader, uploadIdleTimeout, idleCancel)
|
||||
defer idleReader.Stop()
|
||||
|
||||
uploadResp, uploadErr := backuper.apiClient.UploadBasebackup(idleCtx, idleReader)
|
||||
|
||||
if uploadErr != nil && cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
|
||||
cmdErr := cmd.Wait()
|
||||
|
||||
if uploadErr != nil {
|
||||
if cause := context.Cause(idleCtx); cause != nil {
|
||||
uploadErr = cause
|
||||
}
|
||||
|
||||
stderrStr := stderrBuf.String()
|
||||
if stderrStr != "" {
|
||||
return fmt.Errorf("upload basebackup: %w (pg_basebackup stderr: %s)", uploadErr, stderrStr)
|
||||
|
||||
@@ -71,7 +71,7 @@ func Test_RunFullBackup_WhenChainBroken_BasebackupTriggered(t *testing.T) {
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "test-backup-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
@@ -124,7 +124,7 @@ func Test_RunFullBackup_WhenScheduledBackupDue_BasebackupTriggered(t *testing.T)
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "scheduled-backup-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
@@ -169,7 +169,7 @@ func Test_RunFullBackup_WhenNoFullBackupExists_ImmediateBasebackupTriggered(t *t
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "first-backup-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
@@ -233,7 +233,7 @@ func Test_RunFullBackup_WhenUploadFails_RetriesAfterDelay(t *testing.T) {
|
||||
setRetryDelay(100 * time.Millisecond)
|
||||
defer setRetryDelay(origRetryDelay)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
@@ -282,7 +282,7 @@ func Test_RunFullBackup_WhenAlreadyRunning_SkipsExecution(t *testing.T) {
|
||||
|
||||
fb.isRunning.Store(true)
|
||||
|
||||
fb.checkAndRunIfNeeded(context.Background())
|
||||
fb.checkAndRunIfNeeded(t.Context())
|
||||
|
||||
mu.Lock()
|
||||
count := uploadCount
|
||||
@@ -318,7 +318,7 @@ func Test_RunFullBackup_WhenContextCancelled_StopsCleanly(t *testing.T) {
|
||||
setRetryDelay(5 * time.Second)
|
||||
defer setRetryDelay(origRetryDelay)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
@@ -360,7 +360,7 @@ func Test_RunFullBackup_WhenChainValidAndNotScheduled_NoBasebackupTriggered(t *t
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
@@ -411,7 +411,7 @@ func Test_RunFullBackup_WhenStderrParsingFails_FinalizesWithErrorAndRetries(t *t
|
||||
setRetryDelay(100 * time.Millisecond)
|
||||
defer setRetryDelay(origRetryDelay)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
@@ -458,7 +458,7 @@ func Test_RunFullBackup_WhenNextBackupTimeNull_BasebackupTriggered(t *testing.T)
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "first-run-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
@@ -498,7 +498,7 @@ func Test_RunFullBackup_WhenChainValidityReturns401_NoBasebackupTriggered(t *tes
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
@@ -538,7 +538,7 @@ func Test_RunFullBackup_WhenUploadSucceeds_BodyIsZstdCompressed(t *testing.T) {
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, originalContent, validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
@@ -562,6 +562,68 @@ func Test_RunFullBackup_WhenUploadSucceeds_BodyIsZstdCompressed(t *testing.T) {
|
||||
assert.Equal(t, originalContent, string(decompressed))
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenUploadStalls_FailsWithIdleTimeout(t *testing.T) {
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testFullStartPath:
|
||||
// Server reads body normally — it will block until connection is closed
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = stallingCmdBuilder(t)
|
||||
|
||||
origIdleTimeout := uploadIdleTimeout
|
||||
uploadIdleTimeout = 200 * time.Millisecond
|
||||
defer func() { uploadIdleTimeout = origIdleTimeout }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := fb.executeAndUploadBasebackup(ctx)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "idle timeout", "error should mention idle timeout")
|
||||
}
|
||||
|
||||
func stallingCmdBuilder(t *testing.T) CmdBuilder {
|
||||
t.Helper()
|
||||
|
||||
return func(ctx context.Context) *exec.Cmd {
|
||||
cmd := exec.CommandContext(ctx, os.Args[0],
|
||||
"-test.run=TestHelperProcessStalling",
|
||||
"--",
|
||||
)
|
||||
|
||||
cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS_STALLING=1")
|
||||
|
||||
return cmd
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelperProcessStalling(t *testing.T) {
|
||||
if os.Getenv("GO_TEST_HELPER_PROCESS_STALLING") != "1" {
|
||||
return
|
||||
}
|
||||
|
||||
// Write enough data to flush through the zstd encoder's internal buffer (~128KB blocks).
|
||||
// Without enough data, zstd buffers everything and the pipe never receives bytes.
|
||||
data := make([]byte, 256*1024)
|
||||
for i := range data {
|
||||
data[i] = byte(i)
|
||||
}
|
||||
_, _ = os.Stdout.Write(data)
|
||||
|
||||
// Stall with stdout open — the compress goroutine blocks on its next read.
|
||||
// The parent process will kill us when the context is cancelled.
|
||||
time.Sleep(time.Hour)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package restore
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -86,7 +85,7 @@ func Test_RunRestore_WhenBasebackupAndWalSegmentsAvailable_FilesExtractedAndReco
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
err := restorer.Run(context.Background())
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
pgVersionContent, err := os.ReadFile(filepath.Join(targetDir, "PG_VERSION"))
|
||||
@@ -152,7 +151,7 @@ func Test_RunRestore_WhenTargetTimeProvided_RecoveryTargetTimeWrittenToConfig(t
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "2026-02-28T14:30:00Z", "")
|
||||
|
||||
err := restorer.Run(context.Background())
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
|
||||
@@ -169,7 +168,7 @@ func Test_RunRestore_WhenPgDataDirNotEmpty_ReturnsError(t *testing.T) {
|
||||
|
||||
restorer := newTestRestorer("http://localhost:0", targetDir, "", "", "")
|
||||
|
||||
err = restorer.Run(context.Background())
|
||||
err = restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not empty")
|
||||
}
|
||||
@@ -179,7 +178,7 @@ func Test_RunRestore_WhenPgDataDirDoesNotExist_ReturnsError(t *testing.T) {
|
||||
|
||||
restorer := newTestRestorer("http://localhost:0", nonExistentDir, "", "", "")
|
||||
|
||||
err := restorer.Run(context.Background())
|
||||
err := restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "does not exist")
|
||||
}
|
||||
@@ -197,7 +196,7 @@ func Test_RunRestore_WhenNoBackupsAvailable_ReturnsError(t *testing.T) {
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
err := restorer.Run(context.Background())
|
||||
err := restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "No full backups available")
|
||||
}
|
||||
@@ -216,7 +215,7 @@ func Test_RunRestore_WhenWalChainBroken_ReturnsError(t *testing.T) {
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
err := restorer.Run(context.Background())
|
||||
err := restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "WAL chain broken")
|
||||
assert.Contains(t, err.Error(), testWalSegment1)
|
||||
@@ -282,7 +281,7 @@ func Test_DownloadWalSegment_WhenFirstAttemptFails_RetriesAndSucceeds(t *testing
|
||||
retryDelayOverride = &testDelay
|
||||
defer func() { retryDelayOverride = origDelay }()
|
||||
|
||||
err := restorer.Run(context.Background())
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
mu.Lock()
|
||||
@@ -341,7 +340,7 @@ func Test_DownloadWalSegment_WhenAllAttemptsFail_ReturnsErrorWithSegmentName(t *
|
||||
retryDelayOverride = &testDelay
|
||||
defer func() { retryDelayOverride = origDelay }()
|
||||
|
||||
err := restorer.Run(context.Background())
|
||||
err := restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), testWalSegment1)
|
||||
assert.Contains(t, err.Error(), "3 attempts")
|
||||
@@ -351,7 +350,7 @@ func Test_RunRestore_WhenInvalidTargetTimeFormat_ReturnsError(t *testing.T) {
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer("http://localhost:0", targetDir, "", "not-a-valid-time", "")
|
||||
|
||||
err := restorer.Run(context.Background())
|
||||
err := restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid --target-time format")
|
||||
}
|
||||
@@ -384,7 +383,7 @@ func Test_RunRestore_WhenBasebackupDownloadFails_ReturnsError(t *testing.T) {
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
err := restorer.Run(context.Background())
|
||||
err := restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "basebackup download failed")
|
||||
}
|
||||
@@ -423,7 +422,7 @@ func Test_RunRestore_WhenNoWalSegmentsInPlan_BasebackupRestoredSuccessfully(t *t
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
err := restorer.Run(context.Background())
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
pgVersionContent, err := os.ReadFile(filepath.Join(targetDir, "PG_VERSION"))
|
||||
@@ -486,7 +485,7 @@ func Test_RunRestore_WhenMakingApiCalls_AuthTokenIncludedInRequests(t *testing.T
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
err := restorer.Run(context.Background())
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.GreaterOrEqual(t, int(receivedAuthHeaders.Load()), 2)
|
||||
@@ -530,7 +529,7 @@ func Test_ConfigurePostgresRecovery_WhenPgTypeHost_UsesHostAbsolutePath(t *testi
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "host")
|
||||
|
||||
err := restorer.Run(context.Background())
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
|
||||
@@ -577,7 +576,7 @@ func Test_ConfigurePostgresRecovery_WhenPgTypeDocker_UsesContainerPath(t *testin
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "docker")
|
||||
|
||||
err := restorer.Run(context.Background())
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
|
||||
|
||||
@@ -21,7 +21,7 @@ func Test_NewLockWatcher_CapturesInode(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
_, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
@@ -37,7 +37,7 @@ func Test_LockWatcher_FileUnchanged_ContextNotCancelled(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
@@ -62,7 +62,7 @@ func Test_LockWatcher_FileDeleted_CancelsContext(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
@@ -88,7 +88,7 @@ func Test_LockWatcher_FileReplacedWithDifferentInode_CancelsContext(t *testing.T
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -18,6 +18,8 @@ import (
|
||||
"databasus-agent/internal/features/api"
|
||||
)
|
||||
|
||||
var uploadIdleTimeout = 5 * time.Minute
|
||||
|
||||
const (
|
||||
pollInterval = 10 * time.Second
|
||||
uploadTimeout = 5 * time.Minute
|
||||
@@ -113,7 +115,7 @@ func (s *Streamer) listSegments() ([]string, error) {
|
||||
segments = append(segments, name)
|
||||
}
|
||||
|
||||
sort.Strings(segments)
|
||||
slices.Sort(segments)
|
||||
|
||||
return segments, nil
|
||||
}
|
||||
@@ -122,16 +124,27 @@ func (s *Streamer) uploadSegment(ctx context.Context, segmentName string) error
|
||||
filePath := filepath.Join(s.cfg.PgWalDir, segmentName)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
defer func() { _ = pr.Close() }()
|
||||
|
||||
go s.compressAndStream(pw, filePath)
|
||||
|
||||
uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout)
|
||||
defer cancel()
|
||||
uploadCtx, timeoutCancel := context.WithTimeout(ctx, uploadTimeout)
|
||||
defer timeoutCancel()
|
||||
|
||||
idleCtx, idleCancel := context.WithCancelCause(uploadCtx)
|
||||
defer idleCancel(nil)
|
||||
|
||||
idleReader := api.NewIdleTimeoutReader(pr, uploadIdleTimeout, idleCancel)
|
||||
defer idleReader.Stop()
|
||||
|
||||
s.log.Info("Uploading WAL segment", "segment", segmentName)
|
||||
|
||||
result, err := s.apiClient.UploadWalSegment(uploadCtx, segmentName, pr)
|
||||
result, err := s.apiClient.UploadWalSegment(idleCtx, segmentName, idleReader)
|
||||
if err != nil {
|
||||
if cause := context.Cause(idleCtx); cause != nil {
|
||||
return fmt.Errorf("upload WAL segment: %w", cause)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package wal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -42,7 +44,7 @@ func Test_UploadSegment_SingleSegment_ServerReceivesCorrectHeadersAndBody(t *tes
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
@@ -79,7 +81,7 @@ func Test_UploadSegments_MultipleSegmentsOutOfOrder_UploadedInAscendingOrder(t *
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
@@ -115,7 +117,7 @@ func Test_UploadSegments_DirectoryHasTmpFiles_TmpFilesIgnored(t *testing.T) {
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
@@ -146,7 +148,7 @@ func Test_UploadSegment_DeleteEnabled_FileRemovedAfterUpload(t *testing.T) {
|
||||
apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger())
|
||||
streamer := NewStreamer(cfg, apiClient, logger.GetLogger())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
@@ -174,7 +176,7 @@ func Test_UploadSegment_DeleteDisabled_FileKeptAfterUpload(t *testing.T) {
|
||||
apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger())
|
||||
streamer := NewStreamer(cfg, apiClient, logger.GetLogger())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
@@ -199,7 +201,7 @@ func Test_UploadSegment_ServerReturns500_FileKeptInQueue(t *testing.T) {
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
@@ -223,7 +225,7 @@ func Test_ProcessQueue_EmptyDirectory_NoUploads(t *testing.T) {
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
@@ -238,7 +240,7 @@ func Test_Run_ContextCancelled_StopsImmediately(t *testing.T) {
|
||||
|
||||
streamer := newTestStreamer(walDir, "http://localhost:0")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
@@ -276,7 +278,7 @@ func Test_UploadSegment_ServerReturns409_FileNotDeleted(t *testing.T) {
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
@@ -287,6 +289,49 @@ func Test_UploadSegment_ServerReturns409_FileNotDeleted(t *testing.T) {
|
||||
assert.NoError(t, err, "segment file should not be deleted on gap detection")
|
||||
}
|
||||
|
||||
func Test_UploadSegment_WhenUploadStalls_FailsWithIdleTimeout(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
|
||||
// Use incompressible random data to ensure TCP buffers fill up
|
||||
segmentContent := make([]byte, 1024*1024)
|
||||
_, err := rand.Read(segmentContent)
|
||||
require.NoError(t, err)
|
||||
|
||||
writeTestSegment(t, walDir, "000000010000000100000001", segmentContent)
|
||||
|
||||
var requestReceived atomic.Bool
|
||||
handlerDone := make(chan struct{})
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestReceived.Store(true)
|
||||
|
||||
// Read one byte then stall — simulates a network stall
|
||||
buf := make([]byte, 1)
|
||||
_, _ = r.Body.Read(buf)
|
||||
<-handlerDone
|
||||
}))
|
||||
defer server.Close()
|
||||
defer close(handlerDone)
|
||||
|
||||
origIdleTimeout := uploadIdleTimeout
|
||||
uploadIdleTimeout = 200 * time.Millisecond
|
||||
defer func() { uploadIdleTimeout = origIdleTimeout }()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
uploadErr := streamer.uploadSegment(ctx, "000000010000000100000001")
|
||||
|
||||
assert.Error(t, uploadErr, "upload should fail when stalled")
|
||||
assert.True(t, requestReceived.Load(), "server should have received the request")
|
||||
assert.Contains(t, uploadErr.Error(), "idle timeout", "error should mention idle timeout")
|
||||
|
||||
_, statErr := os.Stat(filepath.Join(walDir, "000000010000000100000001"))
|
||||
assert.NoError(t, statErr, "segment file should remain in queue after idle timeout")
|
||||
}
|
||||
|
||||
func newTestStreamer(walDir, serverURL string) *Streamer {
|
||||
cfg := createTestConfig(walDir, serverURL)
|
||||
apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger())
|
||||
|
||||
@@ -64,16 +64,12 @@ func (w *rotatingWriter) rotate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
loggerInstance *slog.Logger
|
||||
once sync.Once
|
||||
)
|
||||
var loggerInstance *slog.Logger
|
||||
|
||||
var initLogger = sync.OnceFunc(initialize)
|
||||
|
||||
func GetLogger() *slog.Logger {
|
||||
once.Do(func() {
|
||||
initialize()
|
||||
})
|
||||
|
||||
initLogger()
|
||||
return loggerInstance
|
||||
}
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ func Test_Write_MultipleSmallWrites_CurrentSizeAccumulated(t *testing.T) {
|
||||
rw, _, _ := setupRotatingWriter(t, 1024)
|
||||
|
||||
var totalWritten int64
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
data := []byte("line\n")
|
||||
n, err := rw.Write(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -27,6 +27,13 @@ VALKEY_PORT=6379
|
||||
VALKEY_USERNAME=
|
||||
VALKEY_PASSWORD=
|
||||
VALKEY_IS_SSL=false
|
||||
# billing
|
||||
PRICE_PER_GB_CENTS=
|
||||
IS_PADDLE_SANDBOX=true
|
||||
PADDLE_API_KEY=
|
||||
PADDLE_WEBHOOK_SECRET=
|
||||
PADDLE_PRICE_ID=
|
||||
PADDLE_CLIENT_TOKEN=
|
||||
# testing
|
||||
# to get Google Drive env variables: add storage in UI and copy data from added storage here
|
||||
TEST_GOOGLE_DRIVE_CLIENT_ID=
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime/debug"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -25,6 +26,8 @@ import (
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/billing"
|
||||
billing_paddle "databasus-backend/internal/features/billing/paddle"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/disk"
|
||||
"databasus-backend/internal/features/encryption/secrets"
|
||||
@@ -105,7 +108,9 @@ func main() {
|
||||
go generateSwaggerDocs(log)
|
||||
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
ginApp := gin.Default()
|
||||
ginApp := gin.New()
|
||||
ginApp.Use(gin.Logger())
|
||||
ginApp.Use(ginRecoveryWithLogger(log))
|
||||
|
||||
// Add GZIP compression middleware
|
||||
ginApp.Use(gzip.Gzip(
|
||||
@@ -188,7 +193,7 @@ func startServerWithGracefulShutdown(log *slog.Logger, app *gin.Engine) {
|
||||
log.Info("Shutdown signal received")
|
||||
|
||||
// Gracefully shutdown VictoriaLogs writer
|
||||
logger.ShutdownVictoriaLogs(5 * time.Second)
|
||||
logger.ShutdownVictoriaLogs()
|
||||
|
||||
// The context is used to inform the server it has 10 seconds to finish
|
||||
// the request it is currently handling
|
||||
@@ -217,6 +222,10 @@ func setUpRoutes(r *gin.Engine) {
|
||||
backups_controllers.GetPostgresWalBackupController().RegisterRoutes(v1)
|
||||
databases.GetDatabaseController().RegisterPublicRoutes(v1)
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
billing_paddle.GetPaddleBillingController().RegisterPublicRoutes(v1)
|
||||
}
|
||||
|
||||
// Setup auth middleware
|
||||
userService := users_services.GetUserService()
|
||||
authMiddleware := users_middleware.AuthMiddleware(userService)
|
||||
@@ -240,6 +249,7 @@ func setUpRoutes(r *gin.Engine) {
|
||||
audit_logs.GetAuditLogController().RegisterRoutes(protected)
|
||||
users_controllers.GetManagementController().RegisterRoutes(protected)
|
||||
users_controllers.GetSettingsController().RegisterRoutes(protected)
|
||||
billing.GetBillingController().RegisterRoutes(protected)
|
||||
}
|
||||
|
||||
func setUpDependencies() {
|
||||
@@ -252,6 +262,11 @@ func setUpDependencies() {
|
||||
storages.SetupDependencies()
|
||||
backups_config.SetupDependencies()
|
||||
task_cancellation.SetupDependencies()
|
||||
billing.SetupDependencies()
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
billing_paddle.SetupDependencies()
|
||||
}
|
||||
}
|
||||
|
||||
func runBackgroundTasks(log *slog.Logger) {
|
||||
@@ -308,6 +323,12 @@ func runBackgroundTasks(log *slog.Logger) {
|
||||
go runWithPanicLogging(log, "restore nodes registry background service", func() {
|
||||
restoring.GetRestoreNodesRegistry().Run(ctx)
|
||||
})
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
go runWithPanicLogging(log, "billing background service", func() {
|
||||
billing.GetBillingService().Run(ctx, *log)
|
||||
})
|
||||
}
|
||||
} else {
|
||||
log.Info("Skipping primary node tasks as not primary node")
|
||||
}
|
||||
@@ -330,7 +351,7 @@ func runBackgroundTasks(log *slog.Logger) {
|
||||
func runWithPanicLogging(log *slog.Logger, serviceName string, fn func()) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error("Panic in "+serviceName, "error", r)
|
||||
log.Error("Panic in "+serviceName, "error", r, "stacktrace", string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
fn()
|
||||
@@ -410,6 +431,25 @@ func enableCors(ginApp *gin.Engine) {
|
||||
}
|
||||
}
|
||||
|
||||
func ginRecoveryWithLogger(log *slog.Logger) gin.HandlerFunc {
|
||||
return func(ctx *gin.Context) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error("Panic recovered in HTTP handler",
|
||||
"error", r,
|
||||
"stacktrace", string(debug.Stack()),
|
||||
"method", ctx.Request.Method,
|
||||
"path", ctx.Request.URL.Path,
|
||||
)
|
||||
|
||||
ctx.AbortWithStatus(http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func mountFrontend(ginApp *gin.Engine) {
|
||||
staticDir := "./ui/build"
|
||||
ginApp.NoRoute(func(c *gin.Context) {
|
||||
|
||||
@@ -5,6 +5,7 @@ go 1.26.1
|
||||
require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
|
||||
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3
|
||||
github.com/PaddleHQ/paddle-go-sdk v1.0.0
|
||||
github.com/gin-contrib/cors v1.7.5
|
||||
github.com/gin-contrib/gzip v1.2.3
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
@@ -100,6 +101,8 @@ require (
|
||||
github.com/emersion/go-message v0.18.2 // indirect
|
||||
github.com/emersion/go-vcard v0.0.0-20241024213814-c9703dde27ff // indirect
|
||||
github.com/flynn/noise v1.1.0 // indirect
|
||||
github.com/ggicci/httpin v0.19.0 // indirect
|
||||
github.com/ggicci/owl v0.8.2 // indirect
|
||||
github.com/go-chi/chi/v5 v5.2.3 // indirect
|
||||
github.com/go-darwin/apfs v0.0.0-20211011131704-f84b94dbf348 // indirect
|
||||
github.com/go-git/go-billy/v5 v5.6.2 // indirect
|
||||
|
||||
@@ -77,6 +77,8 @@ github.com/Max-Sum/base32768 v0.0.0-20230304063302-18e6ce5945fd/go.mod h1:C8yoIf
|
||||
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/PaddleHQ/paddle-go-sdk v1.0.0 h1:+EXitsPFbRcc0CpQE/MIeudxiVOR8pFe/aOWTEUHDKU=
|
||||
github.com/PaddleHQ/paddle-go-sdk v1.0.0/go.mod h1:kbBBzf0BHEj38QvhtoELqlGip3alKgA/I+vl7RQzB58=
|
||||
github.com/ProtonMail/bcrypt v0.0.0-20210511135022-227b4adcab57/go.mod h1:HecWFHognK8GfRDGnFQbW/LiV7A3MX3gZVs45vk5h8I=
|
||||
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs69zUkSzubzjBbL+cmOXgnmt9Fyd9ug=
|
||||
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo=
|
||||
@@ -248,6 +250,10 @@ github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9t
|
||||
github.com/geoffgarside/ber v1.1.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc=
|
||||
github.com/geoffgarside/ber v1.2.0 h1:/loowoRcs/MWLYmGX9QtIAbA+V/FrnVLsMMPhwiRm64=
|
||||
github.com/geoffgarside/ber v1.2.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc=
|
||||
github.com/ggicci/httpin v0.19.0 h1:p0B3SWLVgg770VirYiHB14M5wdRx3zR8mCTzM/TkTQ8=
|
||||
github.com/ggicci/httpin v0.19.0/go.mod h1:hzsQHcbqLabmGOycf7WNw6AAzcVbsMeoOp46bWAbIWc=
|
||||
github.com/ggicci/owl v0.8.2 h1:og+lhqpzSMPDdEB+NJfzoAJARP7qCG3f8uUC3xvGukA=
|
||||
github.com/ggicci/owl v0.8.2/go.mod h1:PHRD57u41vFN5UtFz2SF79yTVoM3HlWpjMiE+ZU2dj4=
|
||||
github.com/gin-contrib/cors v1.7.5 h1:cXC9SmofOrRg0w9PigwGlHG3ztswH6bqq4vJVXnvYMk=
|
||||
github.com/gin-contrib/cors v1.7.5/go.mod h1:4q3yi7xBEDDWKapjT2o1V7mScKDDr8k+jZ0fSquGoy0=
|
||||
github.com/gin-contrib/gzip v1.2.3 h1:dAhT722RuEG330ce2agAs75z7yB+NKvX/ZM1r8w0u2U=
|
||||
@@ -454,6 +460,8 @@ github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1
|
||||
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
|
||||
github.com/jtolio/noiseconn v0.0.0-20231127013910-f6d9ecbf1de7 h1:JcltaO1HXM5S2KYOYcKgAV7slU0xPy1OcvrVgn98sRQ=
|
||||
github.com/jtolio/noiseconn v0.0.0-20231127013910-f6d9ecbf1de7/go.mod h1:MEkhEPFwP3yudWO0lj6vfYpLIB+3eIcuIW+e0AZzUQk=
|
||||
github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo=
|
||||
github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA=
|
||||
github.com/jzelinskie/whirlpool v0.0.0-20201016144138-0675e54bb004 h1:G+9t9cEtnC9jFiTxyptEKuNIAbiN5ZCQzX2a74lj3xg=
|
||||
github.com/jzelinskie/whirlpool v0.0.0-20201016144138-0675e54bb004/go.mod h1:KmHnJWQrgEvbuy0vcvj00gtMqbvNn1L+3YUZLK/B92c=
|
||||
github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU=
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ilyakaznacheev/cleanenv"
|
||||
"github.com/joho/godotenv"
|
||||
@@ -53,6 +54,20 @@ type EnvVariables struct {
|
||||
TempFolder string
|
||||
SecretKeyPath string
|
||||
|
||||
// Billing (always tax-exclusive)
|
||||
PricePerGBCents int64 `env:"PRICE_PER_GB_CENTS"`
|
||||
MinStorageGB int
|
||||
MaxStorageGB int
|
||||
TrialDuration time.Duration
|
||||
TrialStorageGB int
|
||||
GracePeriod time.Duration
|
||||
// Paddle billing
|
||||
IsPaddleSandbox bool `env:"IS_PADDLE_SANDBOX"`
|
||||
PaddleApiKey string `env:"PADDLE_API_KEY"`
|
||||
PaddleWebhookSecret string `env:"PADDLE_WEBHOOK_SECRET"`
|
||||
PaddlePriceID string `env:"PADDLE_PRICE_ID"`
|
||||
PaddleClientToken string `env:"PADDLE_CLIENT_TOKEN"`
|
||||
|
||||
TestGoogleDriveClientID string `env:"TEST_GOOGLE_DRIVE_CLIENT_ID"`
|
||||
TestGoogleDriveClientSecret string `env:"TEST_GOOGLE_DRIVE_CLIENT_SECRET"`
|
||||
TestGoogleDriveTokenJSON string `env:"TEST_GOOGLE_DRIVE_TOKEN_JSON"`
|
||||
@@ -127,14 +142,13 @@ type EnvVariables struct {
|
||||
DatabasusURL string `env:"DATABASUS_URL"`
|
||||
}
|
||||
|
||||
var (
|
||||
env EnvVariables
|
||||
once sync.Once
|
||||
)
|
||||
var env EnvVariables
|
||||
|
||||
func GetEnv() EnvVariables {
|
||||
once.Do(loadEnvVariables)
|
||||
return env
|
||||
var initEnv = sync.OnceFunc(loadEnvVariables)
|
||||
|
||||
func GetEnv() *EnvVariables {
|
||||
initEnv()
|
||||
return &env
|
||||
}
|
||||
|
||||
func loadEnvVariables() {
|
||||
@@ -363,5 +377,39 @@ func loadEnvVariables() {
|
||||
|
||||
}
|
||||
|
||||
// Billing
|
||||
if env.IsCloud {
|
||||
if env.PricePerGBCents == 0 {
|
||||
log.Error("PRICE_PER_GB_CENTS is empty or zero")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.PaddleApiKey == "" {
|
||||
log.Error("PADDLE_API_KEY is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.PaddleWebhookSecret == "" {
|
||||
log.Error("PADDLE_WEBHOOK_SECRET is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.PaddlePriceID == "" {
|
||||
log.Error("PADDLE_PRICE_ID is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.PaddleClientToken == "" {
|
||||
log.Error("PADDLE_CLIENT_TOKEN is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
env.MinStorageGB = 20
|
||||
env.MaxStorageGB = 10_000
|
||||
env.TrialDuration = 24 * time.Hour
|
||||
env.TrialStorageGB = 20
|
||||
env.GracePeriod = 30 * 24 * time.Hour
|
||||
|
||||
log.Info("Environment variables loaded successfully!")
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
@@ -13,39 +12,32 @@ type AuditLogBackgroundService struct {
|
||||
auditLogService *AuditLogService
|
||||
logger *slog.Logger
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
if s.hasRun.Swap(true) {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
s.logger.Info("Starting audit log cleanup background service")
|
||||
|
||||
s.logger.Info("Starting audit log cleanup background service")
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -102,7 +102,7 @@ func Test_CleanOldAuditLogs_DeletesMultipleOldLogs(t *testing.T) {
|
||||
|
||||
// Create many old logs with specific UUIDs to track them
|
||||
testLogIDs := make([]uuid.UUID, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
for i := range 5 {
|
||||
testLogIDs[i] = uuid.New()
|
||||
daysAgo := 400 + (i * 10)
|
||||
log := &AuditLog{
|
||||
|
||||
@@ -2,7 +2,6 @@ package audit_logs
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
@@ -23,8 +22,6 @@ var auditLogController = &AuditLogController{
|
||||
var auditLogBackgroundService = &AuditLogBackgroundService{
|
||||
auditLogService: auditLogService,
|
||||
logger: logger.GetLogger(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
func GetAuditLogService() *AuditLogService {
|
||||
@@ -39,23 +36,8 @@ func GetAuditLogBackgroundService() *AuditLogBackgroundService {
|
||||
return auditLogBackgroundService
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
var SetupDependencies = sync.OnceFunc(func() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
})
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -46,80 +45,73 @@ type BackuperNode struct {
|
||||
|
||||
lastHeartbeat time.Time
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (n *BackuperNode) Run(ctx context.Context) {
|
||||
wasAlreadyRun := n.hasRun.Load()
|
||||
if n.hasRun.Swap(true) {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", n))
|
||||
}
|
||||
|
||||
n.runOnce.Do(func() {
|
||||
n.hasRun.Store(true)
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
|
||||
|
||||
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
|
||||
backupNode := BackupNode{
|
||||
ID: n.nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: time.Now().UTC(),
|
||||
}
|
||||
|
||||
backupNode := BackupNode{
|
||||
ID: n.nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: time.Now().UTC(),
|
||||
}
|
||||
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
|
||||
n.logger.Error("Failed to register node in registry", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
|
||||
n.logger.Error("Failed to register node in registry", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
|
||||
go func() {
|
||||
n.MakeBackup(backupID, isCallNotifier)
|
||||
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
|
||||
n.logger.Error(
|
||||
"Failed to publish backup completion",
|
||||
"error",
|
||||
err,
|
||||
"backupID",
|
||||
backupID,
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
|
||||
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
|
||||
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
|
||||
go func() {
|
||||
n.MakeBackup(backupID, isCallNotifier)
|
||||
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
|
||||
n.logger.Error(
|
||||
"Failed to publish backup completion",
|
||||
"error",
|
||||
err,
|
||||
"backupID",
|
||||
backupID,
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(heartbeatTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
|
||||
|
||||
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
|
||||
n.logger.Error("Failed to unregister node from registry", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
n.sendHeartbeat(&backupNode)
|
||||
}
|
||||
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
|
||||
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
|
||||
}
|
||||
})
|
||||
}()
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", n))
|
||||
ticker := time.NewTicker(heartbeatTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
|
||||
|
||||
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
|
||||
n.logger.Error("Failed to unregister node from registry", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
n.sendHeartbeat(&backupNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,26 +163,6 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
backup.BackupSizeMb = completedMBs
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Check size limit (0 = unlimited)
|
||||
if backupConfig.MaxBackupSizeMB > 0 &&
|
||||
completedMBs > float64(backupConfig.MaxBackupSizeMB) {
|
||||
errMsg := fmt.Sprintf(
|
||||
"backup size (%.2f MB) exceeded maximum allowed size (%d MB)",
|
||||
completedMBs,
|
||||
backupConfig.MaxBackupSizeMB,
|
||||
)
|
||||
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.IsSkipRetry = true
|
||||
backup.FailMessage = &errMsg
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup with size exceeded error", "error", err)
|
||||
}
|
||||
cancel() // Cancel the backup context
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to update backup progress", "error", err)
|
||||
}
|
||||
@@ -308,7 +280,6 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
return
|
||||
}
|
||||
|
||||
backup.Status = backups_core.BackupStatusCompleted
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Update backup with encryption metadata if provided
|
||||
@@ -325,12 +296,6 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
backup.Encryption = backupMetadata.Encryption
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Save metadata file to storage
|
||||
if backupMetadata != nil {
|
||||
metadataJSON, err := json.Marshal(backupMetadata)
|
||||
if err != nil {
|
||||
@@ -363,6 +328,13 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
}
|
||||
}
|
||||
|
||||
backup.Status = backups_core.BackupStatusCompleted
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update database last backup time
|
||||
now := time.Now().UTC()
|
||||
if updateErr := n.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
|
||||
|
||||
@@ -153,121 +153,3 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
assert.Equal(t, notifier.ID, capturedNotifier.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_BackupSizeLimits(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
t.Run("UnlimitedSize_MaxBackupSizeMBIsZero_BackupCompletes", func(t *testing.T) {
|
||||
// Enable backups with unlimited size (0)
|
||||
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backupConfig.MaxBackupSizeMB = 0 // unlimited
|
||||
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.createBackupUseCase = &CreateLargeBackupUsecase{}
|
||||
|
||||
// Create a backup record
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, false)
|
||||
|
||||
// Verify backup completed successfully even with large size
|
||||
updatedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
|
||||
assert.Equal(t, float64(10000), updatedBackup.BackupSizeMb)
|
||||
assert.Nil(t, updatedBackup.FailMessage)
|
||||
})
|
||||
|
||||
t.Run("SizeExceeded_BackupFailedWithIsSkipRetry", func(t *testing.T) {
|
||||
// Enable backups with 5 MB limit
|
||||
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backupConfig.MaxBackupSizeMB = 5
|
||||
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.createBackupUseCase = &CreateProgressiveBackupUsecase{}
|
||||
|
||||
// Create a backup record
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, false)
|
||||
|
||||
// Verify backup was marked as failed with IsSkipRetry=true
|
||||
updatedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusFailed, updatedBackup.Status)
|
||||
assert.True(t, updatedBackup.IsSkipRetry)
|
||||
assert.NotNil(t, updatedBackup.FailMessage)
|
||||
assert.Contains(t, *updatedBackup.FailMessage, "exceeded maximum allowed size")
|
||||
assert.Contains(t, *updatedBackup.FailMessage, "10.00 MB")
|
||||
assert.Contains(t, *updatedBackup.FailMessage, "5 MB")
|
||||
assert.Greater(t, updatedBackup.BackupSizeMb, float64(5))
|
||||
})
|
||||
|
||||
t.Run("SizeWithinLimit_BackupCompletes", func(t *testing.T) {
|
||||
// Enable backups with 100 MB limit
|
||||
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backupConfig.MaxBackupSizeMB = 100
|
||||
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.createBackupUseCase = &CreateMediumBackupUsecase{}
|
||||
|
||||
// Create a backup record
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, false)
|
||||
|
||||
// Verify backup completed successfully
|
||||
updatedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
|
||||
assert.Equal(t, float64(50), updatedBackup.BackupSizeMb)
|
||||
assert.Nil(t, updatedBackup.FailMessage)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,12 +4,12 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/storages"
|
||||
@@ -26,49 +26,47 @@ type BackupCleaner struct {
|
||||
backupRepository *backups_core.BackupRepository
|
||||
storageService *storages.StorageService
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
billingService BillingService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
logger *slog.Logger
|
||||
backupRemoveListeners []backups_core.BackupRemoveListener
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) Run(ctx context.Context) {
|
||||
wasAlreadyRun := c.hasRun.Load()
|
||||
if c.hasRun.Swap(true) {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", c))
|
||||
}
|
||||
|
||||
c.runOnce.Do(func() {
|
||||
c.hasRun.Store(true)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
retentionLog := c.logger.With("task_name", "clean_by_retention_policy")
|
||||
exceededLog := c.logger.With("task_name", "clean_exceeded_storage_backups")
|
||||
staleLog := c.logger.With("task_name", "clean_stale_basebackups")
|
||||
|
||||
ticker := time.NewTicker(cleanerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
if err := c.cleanByRetentionPolicy(retentionLog); err != nil {
|
||||
retentionLog.Error("failed to clean backups by retention policy", "error", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(cleanerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
if err := c.cleanExceededStorageBackups(exceededLog); err != nil {
|
||||
exceededLog.Error("failed to clean exceeded backups", "error", err)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := c.cleanByRetentionPolicy(); err != nil {
|
||||
c.logger.Error("Failed to clean backups by retention policy", "error", err)
|
||||
}
|
||||
|
||||
if err := c.cleanExceededBackups(); err != nil {
|
||||
c.logger.Error("Failed to clean exceeded backups", "error", err)
|
||||
}
|
||||
|
||||
if err := c.cleanStaleUploadedBasebackups(); err != nil {
|
||||
c.logger.Error("Failed to clean stale uploaded basebackups", "error", err)
|
||||
}
|
||||
if err := c.cleanStaleUploadedBasebackups(staleLog); err != nil {
|
||||
staleLog.Error("failed to clean stale uploaded basebackups", "error", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", c))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,7 +102,7 @@ func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemo
|
||||
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
|
||||
func (c *BackupCleaner) cleanStaleUploadedBasebackups(logger *slog.Logger) error {
|
||||
staleBackups, err := c.backupRepository.FindStaleUploadedBasebackups(
|
||||
time.Now().UTC().Add(-10 * time.Minute),
|
||||
)
|
||||
@@ -113,31 +111,30 @@ func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
|
||||
}
|
||||
|
||||
for _, backup := range staleBackups {
|
||||
backupLog := logger.With("database_id", backup.DatabaseID, "backup_id", backup.ID)
|
||||
|
||||
staleStorage, storageErr := c.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr != nil {
|
||||
c.logger.Error(
|
||||
"Failed to get storage for stale basebackup cleanup",
|
||||
"backupId", backup.ID,
|
||||
"storageId", backup.StorageID,
|
||||
backupLog.Error(
|
||||
"failed to get storage for stale basebackup cleanup",
|
||||
"storage_id", backup.StorageID,
|
||||
"error", storageErr,
|
||||
)
|
||||
} else {
|
||||
if err := staleStorage.DeleteFile(c.fieldEncryptor, backup.FileName); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete stale basebackup file",
|
||||
"backupId", backup.ID,
|
||||
"fileName", backup.FileName,
|
||||
"error", err,
|
||||
backupLog.Error(
|
||||
fmt.Sprintf("failed to delete stale basebackup file: %s", backup.FileName),
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
metadataFileName := backup.FileName + ".metadata"
|
||||
if err := staleStorage.DeleteFile(c.fieldEncryptor, metadataFileName); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete stale basebackup metadata file",
|
||||
"backupId", backup.ID,
|
||||
"fileName", metadataFileName,
|
||||
"error", err,
|
||||
backupLog.Error(
|
||||
fmt.Sprintf("failed to delete stale basebackup metadata file: %s", metadataFileName),
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -147,77 +144,67 @@ func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
|
||||
backup.FailMessage = &failMsg
|
||||
|
||||
if err := c.backupRepository.Save(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to mark stale uploaded basebackup as failed",
|
||||
"backupId", backup.ID,
|
||||
"error", err,
|
||||
)
|
||||
backupLog.Error("failed to mark stale uploaded basebackup as failed", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Marked stale uploaded basebackup as failed and cleaned storage",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", backup.DatabaseID,
|
||||
)
|
||||
backupLog.Info("marked stale uploaded basebackup as failed and cleaned storage")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByRetentionPolicy() error {
|
||||
func (c *BackupCleaner) cleanByRetentionPolicy(logger *slog.Logger) error {
|
||||
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
dbLog := logger.With("database_id", backupConfig.DatabaseID, "policy", backupConfig.RetentionPolicyType)
|
||||
|
||||
var cleanErr error
|
||||
|
||||
switch backupConfig.RetentionPolicyType {
|
||||
case backups_config.RetentionPolicyTypeCount:
|
||||
cleanErr = c.cleanByCount(backupConfig)
|
||||
cleanErr = c.cleanByCount(dbLog, backupConfig)
|
||||
case backups_config.RetentionPolicyTypeGFS:
|
||||
cleanErr = c.cleanByGFS(backupConfig)
|
||||
cleanErr = c.cleanByGFS(dbLog, backupConfig)
|
||||
default:
|
||||
cleanErr = c.cleanByTimePeriod(backupConfig)
|
||||
cleanErr = c.cleanByTimePeriod(dbLog, backupConfig)
|
||||
}
|
||||
|
||||
if cleanErr != nil {
|
||||
c.logger.Error(
|
||||
"Failed to clean backups by retention policy",
|
||||
"databaseId", backupConfig.DatabaseID,
|
||||
"policy", backupConfig.RetentionPolicyType,
|
||||
"error", cleanErr,
|
||||
)
|
||||
dbLog.Error("failed to clean backups by retention policy", "error", cleanErr)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanExceededBackups() error {
|
||||
func (c *BackupCleaner) cleanExceededStorageBackups(logger *slog.Logger) error {
|
||||
if !config.GetEnv().IsCloud {
|
||||
return nil
|
||||
}
|
||||
|
||||
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
if backupConfig.MaxBackupsTotalSizeMB <= 0 {
|
||||
dbLog := logger.With("database_id", backupConfig.DatabaseID)
|
||||
|
||||
subscription, subErr := c.billingService.GetSubscription(dbLog, backupConfig.DatabaseID)
|
||||
if subErr != nil {
|
||||
dbLog.Error("failed to get subscription for exceeded backups check", "error", subErr)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.cleanExceededBackupsForDatabase(
|
||||
backupConfig.DatabaseID,
|
||||
backupConfig.MaxBackupsTotalSizeMB,
|
||||
); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to clean exceeded backups for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
storageLimitMB := int64(subscription.GetBackupsStorageGB()) * 1024
|
||||
|
||||
if err := c.cleanExceededBackupsForDatabase(dbLog, backupConfig.DatabaseID, storageLimitMB); err != nil {
|
||||
dbLog.Error("failed to clean exceeded backups for database", "error", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -225,7 +212,7 @@ func (c *BackupCleaner) cleanExceededBackups() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByTimePeriod(backupConfig *backups_config.BackupConfig) error {
|
||||
func (c *BackupCleaner) cleanByTimePeriod(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
|
||||
if backupConfig.RetentionTimePeriod == "" {
|
||||
return nil
|
||||
}
|
||||
@@ -255,21 +242,17 @@ func (c *BackupCleaner) cleanByTimePeriod(backupConfig *backups_config.BackupCon
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
|
||||
logger.Error("failed to delete old backup", "backup_id", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted old backup",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", backupConfig.DatabaseID,
|
||||
)
|
||||
logger.Info("deleted old backup", "backup_id", backup.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByCount(backupConfig *backups_config.BackupConfig) error {
|
||||
func (c *BackupCleaner) cleanByCount(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
|
||||
if backupConfig.RetentionCount <= 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -298,28 +281,20 @@ func (c *BackupCleaner) cleanByCount(backupConfig *backups_config.BackupConfig)
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete backup by count policy",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
logger.Error("failed to delete backup by count policy", "backup_id", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted backup by count policy",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", backupConfig.DatabaseID,
|
||||
"retentionCount", backupConfig.RetentionCount,
|
||||
logger.Info(
|
||||
fmt.Sprintf("deleted backup by count policy: retention count is %d", backupConfig.RetentionCount),
|
||||
"backup_id", backup.ID,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByGFS(backupConfig *backups_config.BackupConfig) error {
|
||||
func (c *BackupCleaner) cleanByGFS(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
|
||||
if backupConfig.RetentionGfsHours <= 0 && backupConfig.RetentionGfsDays <= 0 &&
|
||||
backupConfig.RetentionGfsWeeks <= 0 && backupConfig.RetentionGfsMonths <= 0 &&
|
||||
backupConfig.RetentionGfsYears <= 0 {
|
||||
@@ -357,29 +332,20 @@ func (c *BackupCleaner) cleanByGFS(backupConfig *backups_config.BackupConfig) er
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete backup by GFS policy",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
logger.Error("failed to delete backup by GFS policy", "backup_id", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted backup by GFS policy",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", backupConfig.DatabaseID,
|
||||
)
|
||||
logger.Info("deleted backup by GFS policy", "backup_id", backup.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanExceededBackupsForDatabase(
|
||||
logger *slog.Logger,
|
||||
databaseID uuid.UUID,
|
||||
limitperDbMB int64,
|
||||
limitPerDbMB int64,
|
||||
) error {
|
||||
for {
|
||||
backupsTotalSizeMB, err := c.backupRepository.GetTotalSizeByDatabase(databaseID)
|
||||
@@ -387,7 +353,7 @@ func (c *BackupCleaner) cleanExceededBackupsForDatabase(
|
||||
return err
|
||||
}
|
||||
|
||||
if backupsTotalSizeMB <= float64(limitperDbMB) {
|
||||
if backupsTotalSizeMB <= float64(limitPerDbMB) {
|
||||
break
|
||||
}
|
||||
|
||||
@@ -400,59 +366,27 @@ func (c *BackupCleaner) cleanExceededBackupsForDatabase(
|
||||
}
|
||||
|
||||
if len(oldestBackups) == 0 {
|
||||
c.logger.Warn(
|
||||
"No backups to delete but still over limit",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"totalSizeMB",
|
||||
backupsTotalSizeMB,
|
||||
"limitMB",
|
||||
limitperDbMB,
|
||||
)
|
||||
logger.Warn(fmt.Sprintf(
|
||||
"no backups to delete but still over limit: total size is %.1f MB, limit is %d MB",
|
||||
backupsTotalSizeMB, limitPerDbMB,
|
||||
))
|
||||
break
|
||||
}
|
||||
|
||||
backup := oldestBackups[0]
|
||||
if isRecentBackup(backup) {
|
||||
c.logger.Warn(
|
||||
"Oldest backup is too recent to delete, stopping size cleanup",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"totalSizeMB",
|
||||
backupsTotalSizeMB,
|
||||
"limitMB",
|
||||
limitperDbMB,
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete exceeded backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
logger.Error("failed to delete exceeded backup", "backup_id", backup.ID, "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted exceeded backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"backupSizeMB",
|
||||
backup.BackupSizeMb,
|
||||
"totalSizeMB",
|
||||
backupsTotalSizeMB,
|
||||
"limitMB",
|
||||
limitperDbMB,
|
||||
logger.Info(
|
||||
fmt.Sprintf("deleted exceeded backup: backup size is %.1f MB, total size is %.1f MB, limit is %d MB",
|
||||
backup.BackupSizeMb, backupsTotalSizeMB, limitPerDbMB),
|
||||
"backup_id", backup.ID,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
|
||||
// backupsEveryDay returns n backups, newest-first, each 1 day apart.
|
||||
backupsEveryDay := func(n int) []*backups_core.Backup {
|
||||
bs := make([]*backups_core.Backup, n)
|
||||
for i := 0; i < n; i++ {
|
||||
for i := range n {
|
||||
bs[i] = newBackup(ref.Add(-time.Duration(i) * day))
|
||||
}
|
||||
return bs
|
||||
@@ -42,7 +42,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
|
||||
// backupsEveryWeek returns n backups, newest-first, each 7 days apart.
|
||||
backupsEveryWeek := func(n int) []*backups_core.Backup {
|
||||
bs := make([]*backups_core.Backup, n)
|
||||
for i := 0; i < n; i++ {
|
||||
for i := range n {
|
||||
bs[i] = newBackup(ref.Add(-time.Duration(i) * week))
|
||||
}
|
||||
return bs
|
||||
@@ -53,7 +53,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
|
||||
// backupsEveryHour returns n backups, newest-first, each 1 hour apart.
|
||||
backupsEveryHour := func(n int) []*backups_core.Backup {
|
||||
bs := make([]*backups_core.Backup, n)
|
||||
for i := 0; i < n; i++ {
|
||||
for i := range n {
|
||||
bs[i] = newBackup(ref.Add(-time.Duration(i) * hour))
|
||||
}
|
||||
return bs
|
||||
@@ -62,7 +62,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
|
||||
// backupsEveryMonth returns n backups, newest-first, each ~1 month apart.
|
||||
backupsEveryMonth := func(n int) []*backups_core.Backup {
|
||||
bs := make([]*backups_core.Backup, n)
|
||||
for i := 0; i < n; i++ {
|
||||
for i := range n {
|
||||
bs[i] = newBackup(ref.AddDate(0, -i, 0))
|
||||
}
|
||||
return bs
|
||||
@@ -71,7 +71,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
|
||||
// backupsEveryYear returns n backups, newest-first, each 1 year apart.
|
||||
backupsEveryYear := func(n int) []*backups_core.Backup {
|
||||
bs := make([]*backups_core.Backup, n)
|
||||
for i := 0; i < n; i++ {
|
||||
for i := range n {
|
||||
bs[i] = newBackup(ref.AddDate(-i, 0, 0))
|
||||
}
|
||||
return bs
|
||||
@@ -410,7 +410,7 @@ func Test_CleanByGFS_KeepsCorrectBackupsPerSlot(t *testing.T) {
|
||||
|
||||
// Create 5 backups on 5 different days; only the 3 newest days should be kept
|
||||
var backupIDs []uuid.UUID
|
||||
for i := 0; i < 5; i++ {
|
||||
for i := range 5 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -425,7 +425,7 @@ func Test_CleanByGFS_KeepsCorrectBackupsPerSlot(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -486,7 +486,7 @@ func Test_CleanByGFS_WithWeeklyAndMonthlySlots_KeepsWiderSpread(t *testing.T) {
|
||||
// Create one backup per week for 6 weeks (each on Monday of that week)
|
||||
// GFS should keep: 2 daily (most recent 2 unique days) + 2 weekly + 1 monthly = up to 5 unique
|
||||
var createdIDs []uuid.UUID
|
||||
for i := 0; i < 6; i++ {
|
||||
for i := range 6 {
|
||||
weekOffset := time.Duration(5-i) * 7 * 24 * time.Hour
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
@@ -502,7 +502,7 @@ func Test_CleanByGFS_WithWeeklyAndMonthlySlots_KeepsWiderSpread(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -561,7 +561,7 @@ func Test_CleanByGFS_WithHourlySlots_KeepsCorrectBackups(t *testing.T) {
|
||||
|
||||
// Create 5 backups spaced 1 hour apart; only the 3 newest hours should be kept
|
||||
var backupIDs []uuid.UUID
|
||||
for i := 0; i < 5; i++ {
|
||||
for i := range 5 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -576,7 +576,7 @@ func Test_CleanByGFS_WithHourlySlots_KeepsCorrectBackups(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -677,7 +677,7 @@ func Test_CleanByGFS_SkipsRecentBackup_WhenNotInKeepSet(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -759,7 +759,7 @@ func Test_CleanByGFS_With20DailyBackups_KeepsOnlyExpectedCount(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -824,8 +824,8 @@ func Test_CleanByGFS_WithMultipleBackupsPerDay_KeepsOnlyOnePerDailySlot(t *testi
|
||||
|
||||
// Create 3 backups per day for 10 days = 30 total, all beyond grace period.
|
||||
// Each day gets backups at base+0h, base+6h, base+12h.
|
||||
for day := 0; day < 10; day++ {
|
||||
for sub := 0; sub < 3; sub++ {
|
||||
for day := range 10 {
|
||||
for sub := range 3 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -844,7 +844,7 @@ func Test_CleanByGFS_WithMultipleBackupsPerDay_KeepsOnlyOnePerDailySlot(t *testi
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -915,7 +915,7 @@ func Test_CleanByGFS_With24HourlySlotsAnd23DailyBackups_DeletesExcessBackups(t *
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
for i := 0; i < 23; i++ {
|
||||
for i := range 23 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -929,7 +929,7 @@ func Test_CleanByGFS_With24HourlySlotsAnd23DailyBackups_DeletesExcessBackups(t *
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -985,7 +985,7 @@ func Test_CleanByGFS_WithDisabledHourlySlotsAnd23DailyBackups_DeletesExcessBacku
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
for i := 0; i < 23; i++ {
|
||||
for i := range 23 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -999,7 +999,7 @@ func Test_CleanByGFS_WithDisabledHourlySlotsAnd23DailyBackups_DeletesExcessBacku
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -1055,7 +1055,7 @@ func Test_CleanByGFS_WithDailySlotsAndWeeklyBackups_DeletesExcessBackups(t *test
|
||||
// Create 10 weekly backups (1 per week, all >2h old past grace period).
|
||||
// With 7d/4w config, correct behavior: ~8 kept (4 weekly + overlap with daily for recent ones).
|
||||
// Daily slots should NOT absorb weekly backups that are older than 7 days.
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -1069,7 +1069,7 @@ func Test_CleanByGFS_WithDailySlotsAndWeeklyBackups_DeletesExcessBackups(t *test
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -1138,7 +1138,7 @@ func Test_CleanByGFS_WithWeeklySlotsAndMonthlyBackups_DeletesExcessBackups(t *te
|
||||
// With 52w/3m config, correct behavior: 3 kept (3 monthly slots; weekly should only
|
||||
// cover recent 52 weeks but not artificially retain old monthly backups).
|
||||
// Bug: all 8 kept because each monthly backup fills a unique weekly slot.
|
||||
for i := 0; i < 8; i++ {
|
||||
for i := range 8 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -1152,7 +1152,7 @@ func Test_CleanByGFS_WithWeeklySlotsAndMonthlyBackups_DeletesExcessBackups(t *te
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
@@ -17,6 +20,7 @@ import (
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
|
||||
@@ -51,6 +55,7 @@ func Test_CleanOldBackups_DeletesBackupsOlderThanRetentionTimePeriod(t *testing.
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -89,7 +94,7 @@ func Test_CleanOldBackups_DeletesBackupsOlderThanRetentionTimePeriod(t *testing.
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -129,6 +134,7 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -145,7 +151,7 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -154,7 +160,8 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
|
||||
assert.Equal(t, oldBackup.ID, remainingBackups[0].ID)
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
|
||||
func Test_CleanExceededBackups_WhenUnderStorageLimit_NoBackupsDeleted(t *testing.T) {
|
||||
enableCloud(t)
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
@@ -178,33 +185,36 @@ func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 100,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
for i := range 3 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 16.67,
|
||||
BackupSizeMb: 100,
|
||||
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
|
||||
}
|
||||
cleaner := CreateTestBackupCleaner(mockBilling)
|
||||
err = cleaner.cleanExceededStorageBackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -212,7 +222,8 @@ func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
|
||||
assert.Equal(t, 3, len(remainingBackups))
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T) {
|
||||
func Test_CleanExceededBackups_WhenOverStorageLimit_DeletesOldestBackups(t *testing.T) {
|
||||
enableCloud(t)
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
@@ -236,27 +247,29 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 30,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 5 backups at 300 MB each = 1500 MB total, limit = 1 GB (1024 MB)
|
||||
// Expect 2 oldest deleted, 3 remain (900 MB < 1024 MB)
|
||||
now := time.Now().UTC()
|
||||
var backupIDs []uuid.UUID
|
||||
for i := 0; i < 5; i++ {
|
||||
for i := range 5 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10,
|
||||
BackupSizeMb: 300,
|
||||
CreatedAt: now.Add(-time.Duration(4-i) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
@@ -264,8 +277,11 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
|
||||
backupIDs = append(backupIDs, backup.ID)
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
|
||||
}
|
||||
cleaner := CreateTestBackupCleaner(mockBilling)
|
||||
err = cleaner.cleanExceededStorageBackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -284,6 +300,7 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
|
||||
enableCloud(t)
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
@@ -307,28 +324,29 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 50,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
// 3 completed at 500 MB each = 1500 MB, limit = 1 GB (1024 MB)
|
||||
completedBackups := make([]*backups_core.Backup, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
for i := range 3 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 30,
|
||||
BackupSizeMb: 500,
|
||||
CreatedAt: now.Add(-time.Duration(3-i) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
@@ -347,8 +365,11 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
|
||||
err = backupRepository.Save(inProgressBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
|
||||
}
|
||||
cleaner := CreateTestBackupCleaner(mockBilling)
|
||||
err = cleaner.cleanExceededStorageBackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -365,7 +386,8 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
|
||||
assert.True(t, inProgressFound, "In-progress backup should not be deleted")
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
|
||||
func Test_CleanExceededBackups_WithZeroStorageLimit_RemovesAllBackups(t *testing.T) {
|
||||
enableCloud(t)
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
@@ -389,38 +411,42 @@ func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 0,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 100,
|
||||
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
|
||||
CreatedAt: time.Now().UTC().Add(-time.Duration(i+2) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
// StorageGB=0 means no storage allowed — all backups should be removed
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{StorageGB: 0, Status: billing_models.StatusActive},
|
||||
}
|
||||
cleaner := CreateTestBackupCleaner(mockBilling)
|
||||
err = cleaner.cleanExceededStorageBackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 10, len(remainingBackups))
|
||||
assert.Equal(t, 0, len(remainingBackups))
|
||||
}
|
||||
|
||||
func Test_GetTotalSizeByDatabase_CalculatesCorrectly(t *testing.T) {
|
||||
@@ -522,13 +548,14 @@ func Test_CleanByCount_KeepsNewestNBackups_DeletesOlder(t *testing.T) {
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
now := time.Now().UTC()
|
||||
var backupIDs []uuid.UUID
|
||||
for i := 0; i < 5; i++ {
|
||||
for i := range 5 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -545,7 +572,7 @@ func Test_CleanByCount_KeepsNewestNBackups_DeletesOlder(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -594,11 +621,12 @@ func Test_CleanByCount_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
for i := range 5 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -612,7 +640,7 @@ func Test_CleanByCount_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -651,13 +679,14 @@ func Test_CleanByCount_DoesNotDeleteInProgressBackups(t *testing.T) {
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
for i := range 3 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -682,7 +711,7 @@ func Test_CleanByCount_DoesNotDeleteInProgressBackups(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -776,6 +805,7 @@ func Test_CleanByTimePeriod_SkipsRecentBackup_EvenIfOlderThanRetention(t *testin
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -805,7 +835,7 @@ func Test_CleanByTimePeriod_SkipsRecentBackup_EvenIfOlderThanRetention(t *testin
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -847,6 +877,7 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -893,7 +924,7 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -914,7 +945,8 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
|
||||
assert.True(t, remainingIDs[newestBackup.ID], "Newest backup should be preserved")
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testing.T) {
|
||||
func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverStorageLimit(t *testing.T) {
|
||||
enableCloud(t)
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
@@ -937,18 +969,18 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
|
||||
|
||||
interval := createTestInterval()
|
||||
|
||||
// Total size limit is 10 MB. We have two backups of 8 MB each (16 MB total).
|
||||
// Total size limit = 1 GB (1024 MB). Two backups of 600 MB each (1200 MB total).
|
||||
// The oldest backup was created 30 minutes ago — within the grace period.
|
||||
// The cleaner must stop and leave both backups intact.
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 10,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -960,7 +992,7 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 8,
|
||||
BackupSizeMb: 600,
|
||||
CreatedAt: now.Add(-30 * time.Minute),
|
||||
}
|
||||
newerRecentBackup := &backups_core.Backup{
|
||||
@@ -968,7 +1000,7 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 8,
|
||||
BackupSizeMb: 600,
|
||||
CreatedAt: now.Add(-10 * time.Minute),
|
||||
}
|
||||
|
||||
@@ -977,8 +1009,11 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
|
||||
err = backupRepository.Save(newerRecentBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
|
||||
}
|
||||
cleaner := CreateTestBackupCleaner(mockBilling)
|
||||
err = cleaner.cleanExceededStorageBackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -991,6 +1026,82 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
|
||||
)
|
||||
}
|
||||
|
||||
func Test_CleanExceededStorageBackups_WhenNonCloud_SkipsCleanup(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 5 backups at 500 MB each = 2500 MB, would exceed 1 GB limit in cloud mode
|
||||
now := time.Now().UTC()
|
||||
for i := range 5 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 500,
|
||||
CreatedAt: now.Add(-time.Duration(i+2) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// IsCloud is false by default — cleaner should skip entirely
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
|
||||
}
|
||||
cleaner := CreateTestBackupCleaner(mockBilling)
|
||||
err = cleaner.cleanExceededStorageBackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 5, len(remainingBackups), "All backups must remain in non-cloud mode")
|
||||
}
|
||||
|
||||
type mockBillingService struct {
|
||||
subscription *billing_models.Subscription
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockBillingService) GetSubscription(
|
||||
logger *slog.Logger,
|
||||
databaseID uuid.UUID,
|
||||
) (*billing_models.Subscription, error) {
|
||||
return m.subscription, m.err
|
||||
}
|
||||
|
||||
// Mock listener for testing
|
||||
type mockBackupRemoveListener struct {
|
||||
onBeforeBackupRemove func(*backups_core.Backup) error
|
||||
@@ -1041,7 +1152,7 @@ func Test_CleanStaleUploadedBasebackups_MarksAsFailed(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(staleBackup.ID)
|
||||
@@ -1088,7 +1199,7 @@ func Test_CleanStaleUploadedBasebackups_SkipsRecentUploads(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(recentBackup.ID)
|
||||
@@ -1131,7 +1242,7 @@ func Test_CleanStaleUploadedBasebackups_SkipsActiveStreaming(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(activeBackup.ID)
|
||||
@@ -1179,7 +1290,7 @@ func Test_CleanStaleUploadedBasebackups_CleansStorageFiles(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(staleBackup.ID)
|
||||
@@ -1189,6 +1300,18 @@ func Test_CleanStaleUploadedBasebackups_CleansStorageFiles(t *testing.T) {
|
||||
assert.Contains(t, *updated.FailMessage, "finalization timed out")
|
||||
}
|
||||
|
||||
func enableCloud(t *testing.T) {
|
||||
t.Helper()
|
||||
config.GetEnv().IsCloud = true
|
||||
t.Cleanup(func() {
|
||||
config.GetEnv().IsCloud = false
|
||||
})
|
||||
}
|
||||
|
||||
func testLogger() *slog.Logger {
|
||||
return logger.GetLogger().With("task_name", "test")
|
||||
}
|
||||
|
||||
func createTestInterval() *intervals.Interval {
|
||||
timeOfDay := "04:00"
|
||||
interval := &intervals.Interval{
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -10,6 +9,7 @@ import (
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/billing"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
@@ -28,10 +28,10 @@ var backupCleaner = &BackupCleaner{
|
||||
backupRepository,
|
||||
storages.GetStorageService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
billing.GetBillingService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
logger.GetLogger(),
|
||||
[]backups_core.BackupRemoveListener{},
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
@@ -41,7 +41,6 @@ var backupNodesRegistry = &BackupNodesRegistry{
|
||||
cache_utils.DefaultCacheTimeout,
|
||||
cache_utils.NewPubSubManager(),
|
||||
cache_utils.NewPubSubManager(),
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
@@ -63,7 +62,6 @@ var backuperNode = &BackuperNode{
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
getNodeID(),
|
||||
time.Time{},
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
@@ -73,11 +71,11 @@ var backupsScheduler = &BackupsScheduler{
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
databases.GetDatabaseService(),
|
||||
billing.GetBillingService(),
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode,
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
)
|
||||
|
||||
type BillingService interface {
|
||||
GetSubscription(logger *slog.Logger, databaseID uuid.UUID) (*billing_models.Subscription, error)
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -50,36 +49,30 @@ type BackupNodesRegistry struct {
|
||||
pubsubBackups *cache_utils.PubSubManager
|
||||
pubsubCompletions *cache_utils.PubSubManager
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) Run(ctx context.Context) {
|
||||
wasAlreadyRun := r.hasRun.Load()
|
||||
if r.hasRun.Swap(true) {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", r))
|
||||
}
|
||||
|
||||
r.runOnce.Do(func() {
|
||||
r.hasRun.Store(true)
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
|
||||
}
|
||||
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
|
||||
}
|
||||
ticker := time.NewTicker(cleanupTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
ticker := time.NewTicker(cleanupTickerInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes", "error", err)
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes", "error", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", r))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -322,7 +320,7 @@ func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) {
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer cancel()
|
||||
|
||||
invalidKey := nodeInfoKeyPrefix + uuid.New().String() + nodeInfoKeySuffix
|
||||
@@ -331,7 +329,7 @@ func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) {
|
||||
registry.client.B().Set().Key(invalidKey).Value("invalid json data").Build(),
|
||||
)
|
||||
defer func() {
|
||||
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
cleanupCtx, cleanupCancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer cleanupCancel()
|
||||
registry.client.Do(cleanupCtx, registry.client.B().Del().Key(invalidKey).Build())
|
||||
}()
|
||||
@@ -401,7 +399,7 @@ func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) {
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
|
||||
@@ -419,7 +417,7 @@ func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) {
|
||||
modifiedData, err := json.Marshal(node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
setCtx, setCancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer setCancel()
|
||||
setResult := registry.client.Do(
|
||||
setCtx,
|
||||
@@ -464,7 +462,7 @@ func Test_GetBackupNodesStats_ExcludesStaleNodesFromCache(t *testing.T) {
|
||||
err = registry.IncrementBackupsInProgress(node3.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
|
||||
@@ -482,7 +480,7 @@ func Test_GetBackupNodesStats_ExcludesStaleNodesFromCache(t *testing.T) {
|
||||
modifiedData, err := json.Marshal(node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
setCtx, setCancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer setCancel()
|
||||
setResult := registry.client.Do(
|
||||
setCtx,
|
||||
@@ -524,7 +522,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
|
||||
err = registry.IncrementBackupsInProgress(node2.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
|
||||
@@ -542,7 +540,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
|
||||
modifiedData, err := json.Marshal(node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
setCtx, setCancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer setCancel()
|
||||
setResult := registry.client.Do(
|
||||
setCtx,
|
||||
@@ -553,7 +551,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
|
||||
err = registry.cleanupDeadNodes()
|
||||
assert.NoError(t, err)
|
||||
|
||||
checkCtx, checkCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
checkCtx, checkCancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer checkCancel()
|
||||
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
|
||||
@@ -566,7 +564,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
|
||||
node2.ID.String(),
|
||||
nodeActiveBackupsSuffix,
|
||||
)
|
||||
counterCtx, counterCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
counterCtx, counterCancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer counterCancel()
|
||||
counterResult := registry.client.Do(
|
||||
counterCtx,
|
||||
@@ -575,7 +573,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
|
||||
assert.Error(t, counterResult.Error())
|
||||
|
||||
activeInfoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node1.ID.String(), nodeInfoKeySuffix)
|
||||
activeCtx, activeCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
activeCtx, activeCancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer activeCancel()
|
||||
activeResult := registry.client.Do(
|
||||
activeCtx,
|
||||
@@ -601,8 +599,6 @@ func createTestRegistry() *BackupNodesRegistry {
|
||||
timeout: cache_utils.DefaultCacheTimeout,
|
||||
pubsubBackups: cache_utils.NewPubSubManager(),
|
||||
pubsubCompletions: cache_utils.NewPubSubManager(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -732,7 +728,7 @@ func Test_SubscribeNodeForBackupsAssignment_HandlesInvalidJson(t *testing.T) {
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
err = registry.pubsubBackups.Publish(ctx, "backup:submit", "invalid json")
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -978,7 +974,7 @@ func Test_SubscribeForBackupsCompletions_HandlesInvalidJson(t *testing.T) {
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
err = registry.pubsubCompletions.Publish(ctx, "backup:completion", "invalid json")
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -1093,7 +1089,7 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
|
||||
receivedAll2 := []uuid.UUID{}
|
||||
receivedAll3 := []uuid.UUID{}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
for range 3 {
|
||||
select {
|
||||
case received := <-receivedBackups1:
|
||||
receivedAll1 = append(receivedAll1, received)
|
||||
@@ -1102,7 +1098,7 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
for range 3 {
|
||||
select {
|
||||
case received := <-receivedBackups2:
|
||||
receivedAll2 = append(receivedAll2, received)
|
||||
@@ -1111,7 +1107,7 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
for range 3 {
|
||||
select {
|
||||
case received := <-receivedBackups3:
|
||||
receivedAll3 = append(receivedAll3, received)
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -29,6 +28,7 @@ type BackupsScheduler struct {
|
||||
taskCancelManager *task_cancellation.TaskCancelManager
|
||||
backupNodesRegistry *BackupNodesRegistry
|
||||
databaseService *databases.DatabaseService
|
||||
billingService BillingService
|
||||
|
||||
lastBackupTime time.Time
|
||||
logger *slog.Logger
|
||||
@@ -36,68 +36,61 @@ type BackupsScheduler struct {
|
||||
backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
|
||||
backuperNode *BackuperNode
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) Run(ctx context.Context) {
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
|
||||
if config.GetEnv().IsManyNodesMode {
|
||||
// wait other nodes to start
|
||||
time.Sleep(schedulerStartupDelay)
|
||||
}
|
||||
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to subscribe to backup completions", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
|
||||
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(schedulerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.checkDeadNodesAndFailBackups(); err != nil {
|
||||
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
if s.hasRun.Swap(true) {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
|
||||
if config.GetEnv().IsManyNodesMode {
|
||||
// wait other nodes to start
|
||||
time.Sleep(schedulerStartupDelay)
|
||||
}
|
||||
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to subscribe to backup completions", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
|
||||
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(schedulerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.checkDeadNodesAndFailBackups(); err != nil {
|
||||
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) IsSchedulerRunning() bool {
|
||||
@@ -127,6 +120,34 @@ func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotif
|
||||
return
|
||||
}
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
subscription, subErr := s.billingService.GetSubscription(s.logger, database.ID)
|
||||
if subErr != nil || !subscription.CanCreateNewBackups() {
|
||||
failMessage := "subscription has expired, please renew"
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: *backupConfig.StorageID,
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
IsSkipRetry: true,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
backup.GenerateFilename(database.Name)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error(
|
||||
"failed to save failed backup for expired subscription",
|
||||
"database_id", database.ID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check for existing in-progress backups
|
||||
inProgressBackups, err := s.backupRepository.FindByDatabaseIdAndStatus(
|
||||
database.ID,
|
||||
@@ -346,6 +367,27 @@ func (s *BackupsScheduler) runPendingBackups() error {
|
||||
continue
|
||||
}
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
subscription, subErr := s.billingService.GetSubscription(s.logger, backupConfig.DatabaseID)
|
||||
if subErr != nil {
|
||||
s.logger.Warn(
|
||||
"failed to get subscription, skipping backup",
|
||||
"database_id", backupConfig.DatabaseID,
|
||||
"error", subErr,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
if !subscription.CanCreateNewBackups() {
|
||||
s.logger.Debug(
|
||||
"subscription is not active, skipping scheduled backup",
|
||||
"database_id", backupConfig.DatabaseID,
|
||||
"subscription_status", subscription.Status,
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
s.StartBackup(database, remainedBackupTryCount == 1)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
@@ -968,7 +969,7 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
// Start scheduler so it can handle task completions
|
||||
scheduler := CreateTestScheduler()
|
||||
scheduler := CreateTestScheduler(nil)
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
@@ -1065,7 +1066,7 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
// Start scheduler so it can handle task completions
|
||||
scheduler := CreateTestScheduler()
|
||||
scheduler := CreateTestScheduler(nil)
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
@@ -1332,7 +1333,7 @@ func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCa
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
// Create scheduler
|
||||
scheduler := CreateTestScheduler()
|
||||
scheduler := CreateTestScheduler(nil)
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
@@ -1458,3 +1459,313 @@ func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCa
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartBackup_WhenCloudAndSubscriptionExpired_CreatesFailedBackup(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{
|
||||
Status: billing_models.StatusExpired,
|
||||
},
|
||||
}
|
||||
scheduler := CreateTestScheduler(mockBilling)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
enableCloud(t)
|
||||
|
||||
scheduler.StartBackup(database, false)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
|
||||
newestBackup := backups[0]
|
||||
assert.Equal(t, backups_core.BackupStatusFailed, newestBackup.Status)
|
||||
assert.NotNil(t, newestBackup.FailMessage)
|
||||
assert.Equal(t, "subscription has expired, please renew", *newestBackup.FailMessage)
|
||||
assert.True(t, newestBackup.IsSkipRetry)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartBackup_WhenCloudAndSubscriptionActive_ProceedsNormally(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{
|
||||
Status: billing_models.StatusActive,
|
||||
StorageGB: 10,
|
||||
},
|
||||
}
|
||||
scheduler := CreateTestScheduler(mockBilling)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
enableCloud(t)
|
||||
|
||||
scheduler.StartBackup(database, false)
|
||||
|
||||
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
|
||||
newestBackup := backups[0]
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, newestBackup.Status)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RunPendingBackups_WhenCloudAndSubscriptionExpired_SilentlySkips(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{
|
||||
Status: billing_models.StatusExpired,
|
||||
},
|
||||
}
|
||||
scheduler := CreateTestScheduler(mockBilling)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
|
||||
backupConfig.RetentionTimePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
enableCloud(t)
|
||||
|
||||
scheduler.runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1, "No new backup should be created, scheduler silently skips expired subscriptions")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartBackup_WhenNotCloudAndSubscriptionExpired_ProceedsNormally(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{
|
||||
Status: billing_models.StatusExpired,
|
||||
},
|
||||
}
|
||||
scheduler := CreateTestScheduler(mockBilling)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
scheduler.StartBackup(database, false)
|
||||
|
||||
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backups[0].Status,
|
||||
"Billing check should not apply in non-cloud mode")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RunPendingBackups_WhenNotCloudAndSubscriptionExpired_ProceedsNormally(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{
|
||||
Status: billing_models.StatusExpired,
|
||||
},
|
||||
}
|
||||
scheduler := CreateTestScheduler(mockBilling)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
|
||||
backupConfig.RetentionTimePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
scheduler.runPendingBackups()
|
||||
|
||||
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 2, "Billing check should not apply in non-cloud mode, new backup should be created")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package backuping
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -35,58 +34,70 @@ func CreateTestRouter() *gin.Engine {
|
||||
return router
|
||||
}
|
||||
|
||||
func CreateTestBackupCleaner(billingService BillingService) *BackupCleaner {
|
||||
return &BackupCleaner{
|
||||
backupRepository,
|
||||
storages.GetStorageService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
billingService,
|
||||
encryption.GetFieldEncryptor(),
|
||||
logger.GetLogger(),
|
||||
[]backups_core.BackupRemoveListener{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestBackuperNode() *BackuperNode {
|
||||
return &BackuperNode{
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
createBackupUseCase: usecases.GetCreateBackupUsecase(),
|
||||
nodeID: uuid.New(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
logger.GetLogger(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
uuid.New(),
|
||||
time.Time{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestBackuperNodeWithUseCase(useCase backups_core.CreateBackupUsecase) *BackuperNode {
|
||||
return &BackuperNode{
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
createBackupUseCase: useCase,
|
||||
nodeID: uuid.New(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
logger.GetLogger(),
|
||||
useCase,
|
||||
uuid.New(),
|
||||
time.Time{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestScheduler() *BackupsScheduler {
|
||||
func CreateTestScheduler(billingService BillingService) *BackupsScheduler {
|
||||
return &BackupsScheduler{
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
taskCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
lastBackupTime: time.Now().UTC(),
|
||||
logger: logger.GetLogger(),
|
||||
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode: CreateTestBackuperNode(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
databases.GetDatabaseService(),
|
||||
billingService,
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
make(map[uuid.UUID]BackupToNodeRelation),
|
||||
CreateTestBackuperNode(),
|
||||
atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -40,12 +40,15 @@ func (c *BackupController) RegisterPublicRoutes(router *gin.RouterGroup) {
|
||||
|
||||
// GetBackups
|
||||
// @Summary Get backups for a database
|
||||
// @Description Get paginated backups for the specified database
|
||||
// @Description Get paginated backups for the specified database with optional filters
|
||||
// @Tags backups
|
||||
// @Produce json
|
||||
// @Param database_id query string true "Database ID"
|
||||
// @Param limit query int false "Number of items per page" default(10)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param status query []string false "Filter by backup status (can be repeated)" Enums(IN_PROGRESS, COMPLETED, FAILED, CANCELED)
|
||||
// @Param beforeDate query string false "Filter backups created before this date (RFC3339)" format(date-time)
|
||||
// @Param pgWalBackupType query string false "Filter by WAL backup type" Enums(PG_FULL_BACKUP, PG_WAL_SEGMENT)
|
||||
// @Success 200 {object} backups_dto.GetBackupsResponse
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
@@ -70,7 +73,9 @@ func (c *BackupController) GetBackups(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.backupService.GetBackups(user, databaseID, request.Limit, request.Offset)
|
||||
filters := c.buildBackupFilters(&request)
|
||||
|
||||
response, err := c.backupService.GetBackups(user, databaseID, request.Limit, request.Offset, filters)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -359,3 +364,35 @@ func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uu
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *BackupController) buildBackupFilters(
|
||||
request *backups_dto.GetBackupsRequest,
|
||||
) *backups_core.BackupFilters {
|
||||
isHasFilters := len(request.Statuses) > 0 ||
|
||||
request.BeforeDate != nil ||
|
||||
request.PgWalBackupType != nil
|
||||
|
||||
if !isHasFilters {
|
||||
return nil
|
||||
}
|
||||
|
||||
filters := &backups_core.BackupFilters{}
|
||||
|
||||
if len(request.Statuses) > 0 {
|
||||
statuses := make([]backups_core.BackupStatus, 0, len(request.Statuses))
|
||||
for _, statusStr := range request.Statuses {
|
||||
statuses = append(statuses, backups_core.BackupStatus(statusStr))
|
||||
}
|
||||
|
||||
filters.Statuses = statuses
|
||||
}
|
||||
|
||||
filters.BeforeDate = request.BeforeDate
|
||||
|
||||
if request.PgWalBackupType != nil {
|
||||
walType := backups_core.PgWalBackupType(*request.PgWalBackupType)
|
||||
filters.PgWalBackupType = &walType
|
||||
}
|
||||
|
||||
return filters
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
@@ -140,6 +141,225 @@ func Test_GetBackups_PermissionsEnforced(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GetBackups_WithStatusFilter_ReturnsFilteredBackups(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: now.Add(-3 * time.Hour),
|
||||
})
|
||||
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
CreatedAt: now.Add(-2 * time.Hour),
|
||||
})
|
||||
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusCanceled,
|
||||
CreatedAt: now.Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
// Single status filter
|
||||
var singleResponse backups_dto.GetBackupsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups?database_id=%s&status=COMPLETED", database.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&singleResponse,
|
||||
)
|
||||
|
||||
assert.Equal(t, int64(1), singleResponse.Total)
|
||||
assert.Len(t, singleResponse.Backups, 1)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, singleResponse.Backups[0].Status)
|
||||
|
||||
// Multiple status filter
|
||||
var multiResponse backups_dto.GetBackupsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups?database_id=%s&status=COMPLETED&status=FAILED",
|
||||
database.ID.String(),
|
||||
),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&multiResponse,
|
||||
)
|
||||
|
||||
assert.Equal(t, int64(2), multiResponse.Total)
|
||||
assert.Len(t, multiResponse.Backups, 2)
|
||||
|
||||
for _, backup := range multiResponse.Backups {
|
||||
assert.True(
|
||||
t,
|
||||
backup.Status == backups_core.BackupStatusCompleted ||
|
||||
backup.Status == backups_core.BackupStatusFailed,
|
||||
"expected COMPLETED or FAILED, got %s", backup.Status,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GetBackups_WithBeforeDateFilter_ReturnsFilteredBackups(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
now := time.Now().UTC()
|
||||
cutoff := now.Add(-1 * time.Hour)
|
||||
|
||||
olderBackup := CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: now.Add(-3 * time.Hour),
|
||||
})
|
||||
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: now,
|
||||
})
|
||||
|
||||
var response backups_dto.GetBackupsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups?database_id=%s&beforeDate=%s",
|
||||
database.ID.String(),
|
||||
cutoff.Format(time.RFC3339),
|
||||
),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, int64(1), response.Total)
|
||||
assert.Len(t, response.Backups, 1)
|
||||
assert.Equal(t, olderBackup.ID, response.Backups[0].ID)
|
||||
}
|
||||
|
||||
func Test_GetBackups_WithPgWalBackupTypeFilter_ReturnsFilteredBackups(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
now := time.Now().UTC()
|
||||
fullBackupType := backups_core.PgWalBackupTypeFullBackup
|
||||
walSegmentType := backups_core.PgWalBackupTypeWalSegment
|
||||
|
||||
fullBackup := CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: now.Add(-2 * time.Hour),
|
||||
PgWalBackupType: &fullBackupType,
|
||||
})
|
||||
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: now.Add(-1 * time.Hour),
|
||||
PgWalBackupType: &walSegmentType,
|
||||
})
|
||||
|
||||
var response backups_dto.GetBackupsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups?database_id=%s&pgWalBackupType=PG_FULL_BACKUP",
|
||||
database.ID.String(),
|
||||
),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, int64(1), response.Total)
|
||||
assert.Len(t, response.Backups, 1)
|
||||
assert.Equal(t, fullBackup.ID, response.Backups[0].ID)
|
||||
}
|
||||
|
||||
func Test_GetBackups_WithCombinedFilters_ReturnsFilteredBackups(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
now := time.Now().UTC()
|
||||
cutoff := now.Add(-1 * time.Hour)
|
||||
|
||||
// Old completed — should match
|
||||
oldCompleted := CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: now.Add(-3 * time.Hour),
|
||||
})
|
||||
// Old failed — should NOT match (wrong status)
|
||||
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
CreatedAt: now.Add(-2 * time.Hour),
|
||||
})
|
||||
// New completed — should NOT match (too recent)
|
||||
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: now,
|
||||
})
|
||||
|
||||
var response backups_dto.GetBackupsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups?database_id=%s&status=COMPLETED&beforeDate=%s",
|
||||
database.ID.String(),
|
||||
cutoff.Format(time.RFC3339),
|
||||
),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, int64(1), response.Total)
|
||||
assert.Len(t, response.Backups, 1)
|
||||
assert.Equal(t, oldCompleted.ID, response.Backups[0].ID)
|
||||
}
|
||||
|
||||
func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -376,7 +596,7 @@ func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
|
||||
ownerUser, err := userService.GetUserFromToken(owner.Token)
|
||||
assert.NoError(t, err)
|
||||
|
||||
response, err := backups_services.GetBackupService().GetBackups(ownerUser, database.ID, 10, 0)
|
||||
response, err := backups_services.GetBackupService().GetBackups(ownerUser, database.ID, 10, 0, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(response.Backups))
|
||||
}
|
||||
@@ -1263,7 +1483,7 @@ func Test_MakeBackup_VerifyBackupAndMetadataFilesExistInStorage(t *testing.T) {
|
||||
backuperCancel := backuping.StartBackuperNodeForTest(t, backuperNode)
|
||||
defer backuping.StopBackuperNodeForTest(t, backuperCancel, backuperNode)
|
||||
|
||||
scheduler := backuping.CreateTestScheduler()
|
||||
scheduler := backuping.CreateTestScheduler(nil)
|
||||
schedulerCancel := backuping.StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
@@ -1297,14 +1517,14 @@ func Test_MakeBackup_VerifyBackupAndMetadataFilesExistInStorage(t *testing.T) {
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
|
||||
backupFile, err := backupStorage.GetFile(encryptor, backup.FileName)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
backupFile.Close()
|
||||
|
||||
metadataFile, err := backupStorage.GetFile(encryptor, backup.FileName+".metadata")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
metadataContent, err := io.ReadAll(metadataFile)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
metadataFile.Close()
|
||||
|
||||
var storageMetadata backups_common.BackupMetadata
|
||||
@@ -1838,7 +2058,7 @@ func Test_DeleteBackup_RemovesBackupAndMetadataFilesFromDisk(t *testing.T) {
|
||||
backuperCancel := backuping.StartBackuperNodeForTest(t, backuperNode)
|
||||
defer backuping.StopBackuperNodeForTest(t, backuperCancel, backuperNode)
|
||||
|
||||
scheduler := backuping.CreateTestScheduler()
|
||||
scheduler := backuping.CreateTestScheduler(nil)
|
||||
schedulerCancel := backuping.StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
|
||||
@@ -95,3 +95,33 @@ func CreateTestBackup(databaseID, storageID uuid.UUID) *backups_core.Backup {
|
||||
|
||||
return backup
|
||||
}
|
||||
|
||||
type TestBackupOptions struct {
|
||||
Status backups_core.BackupStatus
|
||||
CreatedAt time.Time
|
||||
PgWalBackupType *backups_core.PgWalBackupType
|
||||
}
|
||||
|
||||
// CreateTestBackupWithOptions creates a test backup with custom status, time, and WAL type
|
||||
func CreateTestBackupWithOptions(
|
||||
databaseID, storageID uuid.UUID,
|
||||
opts TestBackupOptions,
|
||||
) *backups_core.Backup {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: databaseID,
|
||||
StorageID: storageID,
|
||||
Status: opts.Status,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
PgWalBackupType: opts.PgWalBackupType,
|
||||
CreatedAt: opts.CreatedAt,
|
||||
}
|
||||
|
||||
repo := &backups_core.BackupRepository{}
|
||||
if err := repo.Save(backup); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return backup
|
||||
}
|
||||
|
||||
9
backend/internal/features/backups/backups/core/dto.go
Normal file
9
backend/internal/features/backups/backups/core/dto.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package backups_core
|
||||
|
||||
import "time"
|
||||
|
||||
type BackupFilters struct {
|
||||
Statuses []BackupStatus
|
||||
BeforeDate *time.Time
|
||||
PgWalBackupType *PgWalBackupType
|
||||
}
|
||||
@@ -422,3 +422,67 @@ func (r *BackupRepository) FindLastWalSegmentAfter(
|
||||
|
||||
return &backup, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) FindByDatabaseIDWithFiltersAndPagination(
|
||||
databaseID uuid.UUID,
|
||||
filters *BackupFilters,
|
||||
limit, offset int,
|
||||
) ([]*Backup, error) {
|
||||
var backups []*Backup
|
||||
|
||||
query := storage.
|
||||
GetDb().
|
||||
Where("database_id = ?", databaseID)
|
||||
|
||||
if filters != nil {
|
||||
query = filters.applyToQuery(query)
|
||||
}
|
||||
|
||||
if err := query.
|
||||
Order("created_at DESC").
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&backups).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return backups, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) CountByDatabaseIDWithFilters(
|
||||
databaseID uuid.UUID,
|
||||
filters *BackupFilters,
|
||||
) (int64, error) {
|
||||
var count int64
|
||||
|
||||
query := storage.
|
||||
GetDb().
|
||||
Model(&Backup{}).
|
||||
Where("database_id = ?", databaseID)
|
||||
|
||||
if filters != nil {
|
||||
query = filters.applyToQuery(query)
|
||||
}
|
||||
|
||||
if err := query.Count(&count).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (f *BackupFilters) applyToQuery(query *gorm.DB) *gorm.DB {
|
||||
if len(f.Statuses) > 0 {
|
||||
query = query.Where("status IN ?", f.Statuses)
|
||||
}
|
||||
|
||||
if f.BeforeDate != nil {
|
||||
query = query.Where("created_at < ?", *f.BeforeDate)
|
||||
}
|
||||
|
||||
if f.PgWalBackupType != nil {
|
||||
query = query.Where("pg_wal_backup_type = ?", *f.PgWalBackupType)
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
@@ -13,38 +12,31 @@ type DownloadTokenBackgroundService struct {
|
||||
downloadTokenService *DownloadTokenService
|
||||
logger *slog.Logger
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
s.logger.Info("Starting download token cleanup background service")
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
|
||||
s.logger.Error("Failed to clean expired download tokens", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
if s.hasRun.Swap(true) {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
|
||||
s.logger.Info("Starting download token cleanup background service")
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
|
||||
s.logger.Error("Failed to clean expired download tokens", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
@@ -37,8 +34,6 @@ func init() {
|
||||
downloadTokenBackgroundService = &DownloadTokenBackgroundService{
|
||||
downloadTokenService: downloadTokenService,
|
||||
logger: logger.GetLogger(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -11,9 +11,12 @@ import (
|
||||
)
|
||||
|
||||
type GetBackupsRequest struct {
|
||||
DatabaseID string `form:"database_id" binding:"required"`
|
||||
Limit int `form:"limit"`
|
||||
Offset int `form:"offset"`
|
||||
DatabaseID string `form:"database_id" binding:"required"`
|
||||
Limit int `form:"limit"`
|
||||
Offset int `form:"offset"`
|
||||
Statuses []string `form:"status"`
|
||||
BeforeDate *time.Time `form:"beforeDate"`
|
||||
PgWalBackupType *string `form:"pgWalBackupType"`
|
||||
}
|
||||
|
||||
type GetBackupsResponse struct {
|
||||
|
||||
@@ -2,7 +2,6 @@ package backups_services
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
@@ -59,26 +58,11 @@ func GetWalService() *PostgreWalBackupService {
|
||||
return walService
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
var SetupDependencies = sync.OnceFunc(func() {
|
||||
backups_config.
|
||||
GetBackupConfigService().
|
||||
SetDatabaseStorageChangeListener(backupService)
|
||||
|
||||
func SetupDependencies() {
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
backups_config.
|
||||
GetBackupConfigService().
|
||||
SetDatabaseStorageChangeListener(backupService)
|
||||
|
||||
databases.GetDatabaseService().AddDbRemoveListener(backupService)
|
||||
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
databases.GetDatabaseService().AddDbRemoveListener(backupService)
|
||||
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
|
||||
})
|
||||
|
||||
@@ -109,6 +109,7 @@ func (s *BackupService) GetBackups(
|
||||
user *users_models.User,
|
||||
databaseID uuid.UUID,
|
||||
limit, offset int,
|
||||
filters *backups_core.BackupFilters,
|
||||
) (*backups_dto.GetBackupsResponse, error) {
|
||||
database, err := s.databaseService.GetDatabaseByID(databaseID)
|
||||
if err != nil {
|
||||
@@ -134,12 +135,14 @@ func (s *BackupService) GetBackups(
|
||||
offset = 0
|
||||
}
|
||||
|
||||
backups, err := s.backupRepository.FindByDatabaseIDWithPagination(databaseID, limit, offset)
|
||||
backups, err := s.backupRepository.FindByDatabaseIDWithFiltersAndPagination(
|
||||
databaseID, filters, limit, offset,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
total, err := s.backupRepository.CountByDatabaseID(databaseID)
|
||||
total, err := s.backupRepository.CountByDatabaseIDWithFilters(databaseID, filters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -281,15 +281,9 @@ func (uc *CreateMariadbBackupUsecase) createTempMyCnfFile(
|
||||
mdbConfig *mariadbtypes.MariadbDatabase,
|
||||
password string,
|
||||
) (string, error) {
|
||||
tempFolder := config.GetEnv().TempFolder
|
||||
if err := os.MkdirAll(tempFolder, 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
|
||||
}
|
||||
if err := os.Chmod(tempFolder, 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp(tempFolder, "mycnf_"+uuid.New().String())
|
||||
// Credential files use OS temp dir (/tmp) because some filesystems
|
||||
// (e.g. ZFS on TrueNAS) ignore chmod, causing "group or world access" errors.
|
||||
tempDir, err := os.MkdirTemp(os.TempDir(), "mycnf_"+uuid.New().String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
@@ -300,15 +300,9 @@ func (uc *CreateMysqlBackupUsecase) createTempMyCnfFile(
|
||||
myConfig *mysqltypes.MysqlDatabase,
|
||||
password string,
|
||||
) (string, error) {
|
||||
tempFolder := config.GetEnv().TempFolder
|
||||
if err := os.MkdirAll(tempFolder, 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
|
||||
}
|
||||
if err := os.Chmod(tempFolder, 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp(tempFolder, "mycnf_"+uuid.New().String())
|
||||
// Credential files use OS temp dir (/tmp) because some filesystems
|
||||
// (e.g. ZFS on TrueNAS) ignore chmod, causing "group or world access" errors.
|
||||
tempDir, err := os.MkdirTemp(os.TempDir(), "mycnf_"+uuid.New().String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
@@ -747,15 +747,9 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
|
||||
escapedPassword,
|
||||
)
|
||||
|
||||
tempFolder := config.GetEnv().TempFolder
|
||||
if err := os.MkdirAll(tempFolder, 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
|
||||
}
|
||||
if err := os.Chmod(tempFolder, 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp(tempFolder, "pgpass_"+uuid.New().String())
|
||||
// Credential files use OS temp dir (/tmp) because some filesystems
|
||||
// (e.g. ZFS on TrueNAS) ignore chmod, causing "group or world access" errors.
|
||||
tempDir, err := os.MkdirTemp(os.TempDir(), "pgpass_"+uuid.New().String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temporary directory: %w", err)
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ type BackupConfigController struct {
|
||||
|
||||
func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
router.POST("/backup-configs/save", c.SaveBackupConfig)
|
||||
router.GET("/backup-configs/database/:id/plan", c.GetDatabasePlan)
|
||||
router.GET("/backup-configs/database/:id", c.GetBackupConfigByDbID)
|
||||
router.GET("/backup-configs/storage/:id/is-using", c.IsStorageUsing)
|
||||
router.GET("/backup-configs/storage/:id/databases-count", c.CountDatabasesForStorage)
|
||||
@@ -93,39 +92,6 @@ func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
|
||||
ctx.JSON(http.StatusOK, backupConfig)
|
||||
}
|
||||
|
||||
// GetDatabasePlan
|
||||
// @Summary Get database plan by database ID
|
||||
// @Description Get the plan limits for a specific database (max backup size, max total size, max storage period)
|
||||
// @Tags backup-configs
|
||||
// @Produce json
|
||||
// @Param id path string true "Database ID"
|
||||
// @Success 200 {object} plans.DatabasePlan
|
||||
// @Failure 400 {object} map[string]string "Invalid database ID"
|
||||
// @Failure 401 {object} map[string]string "User not authenticated"
|
||||
// @Failure 404 {object} map[string]string "Database not found or access denied"
|
||||
// @Router /backup-configs/database/{id}/plan [get]
|
||||
func (c *BackupConfigController) GetDatabasePlan(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
|
||||
return
|
||||
}
|
||||
|
||||
plan, err := c.backupConfigService.GetDatabasePlan(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusNotFound, gin.H{"error": "database plan not found"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, plan)
|
||||
}
|
||||
|
||||
// IsStorageUsing
|
||||
// @Summary Check if storage is being used
|
||||
// @Description Check if a storage is currently being used by any backup configuration
|
||||
|
||||
@@ -17,14 +17,12 @@ import (
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/features/storages"
|
||||
local_storage "databasus-backend/internal/features/storages/models/local"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
"databasus-backend/internal/util/period"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
"databasus-backend/internal/util/tools"
|
||||
@@ -326,218 +324,13 @@ func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T)
|
||||
&response,
|
||||
)
|
||||
|
||||
var plan plans.DatabasePlan
|
||||
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&plan,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.False(t, response.IsBackupsEnabled)
|
||||
assert.Equal(t, plan.MaxStoragePeriod, response.RetentionTimePeriod)
|
||||
assert.Equal(t, plan.MaxBackupSizeMB, response.MaxBackupSizeMB)
|
||||
assert.Equal(t, plan.MaxBackupsTotalSizeMB, response.MaxBackupsTotalSizeMB)
|
||||
assert.True(t, response.IsRetryIfFailed)
|
||||
assert.Equal(t, 3, response.MaxFailedTriesCount)
|
||||
assert.NotNil(t, response.BackupInterval)
|
||||
}
|
||||
|
||||
func Test_GetDatabasePlan_ForNewDatabase_PlanAlwaysReturned(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
var response plans.DatabasePlan
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.NotNil(t, response.MaxBackupSizeMB)
|
||||
assert.NotNil(t, response.MaxBackupsTotalSizeMB)
|
||||
assert.NotEmpty(t, response.MaxStoragePeriod)
|
||||
}
|
||||
|
||||
func Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Get plan via API (triggers auto-creation)
|
||||
var plan plans.DatabasePlan
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&plan,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, plan.DatabaseID)
|
||||
|
||||
// Adjust plan limits directly in database to fixed restrictive values
|
||||
err := storage.GetDb().Model(&plans.DatabasePlan{}).
|
||||
Where("database_id = ?", database.ID).
|
||||
Updates(map[string]any{
|
||||
"max_backup_size_mb": 100,
|
||||
"max_backups_total_size_mb": 1000,
|
||||
"max_storage_period": period.PeriodMonth,
|
||||
}).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test 1: Try to save backup config with exceeded backup size limit
|
||||
timeOfDay := "04:00"
|
||||
backupConfigExceededSize := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 200, // Exceeds limit of 100
|
||||
MaxBackupsTotalSizeMB: 800,
|
||||
}
|
||||
|
||||
respExceededSize := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
backupConfigExceededSize,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
assert.Contains(t, string(respExceededSize.Body), "max backup size exceeds plan limit")
|
||||
|
||||
// Test 2: Try to save backup config with exceeded total size limit
|
||||
backupConfigExceededTotal := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 50,
|
||||
MaxBackupsTotalSizeMB: 2000, // Exceeds limit of 1000
|
||||
}
|
||||
|
||||
respExceededTotal := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
backupConfigExceededTotal,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
assert.Contains(t, string(respExceededTotal.Body), "max total backups size exceeds plan limit")
|
||||
|
||||
// Test 3: Try to save backup config with exceeded storage period limit
|
||||
backupConfigExceededPeriod := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodYear, // Exceeds limit of Month
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 80,
|
||||
MaxBackupsTotalSizeMB: 800,
|
||||
}
|
||||
|
||||
respExceededPeriod := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
backupConfigExceededPeriod,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
assert.Contains(t, string(respExceededPeriod.Body), "storage period exceeds plan limit")
|
||||
|
||||
// Test 4: Save backup config within all limits - should succeed
|
||||
backupConfigValid := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodWeek, // Within Month limit
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 80, // Within 100 limit
|
||||
MaxBackupsTotalSizeMB: 800, // Within 1000 limit
|
||||
}
|
||||
|
||||
var responseValid BackupConfig
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
backupConfigValid,
|
||||
http.StatusOK,
|
||||
&responseValid,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, responseValid.DatabaseID)
|
||||
assert.Equal(t, int64(80), responseValid.MaxBackupSizeMB)
|
||||
assert.Equal(t, int64(800), responseValid.MaxBackupsTotalSizeMB)
|
||||
assert.Equal(t, period.PeriodWeek, responseValid.RetentionTimePeriod)
|
||||
}
|
||||
|
||||
func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -2,14 +2,11 @@ package backups_config
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/features/storages"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -20,7 +17,6 @@ var (
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
plans.GetDatabasePlanService(),
|
||||
nil,
|
||||
}
|
||||
)
|
||||
@@ -37,21 +33,6 @@ func GetBackupConfigService() *BackupConfigService {
|
||||
return backupConfigService
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
var SetupDependencies = sync.OnceFunc(func() {
|
||||
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
|
||||
})
|
||||
|
||||
@@ -7,5 +7,5 @@ type TransferDatabaseRequest struct {
|
||||
TargetStorageID *uuid.UUID `json:"targetStorageId,omitempty"`
|
||||
IsTransferWithStorage bool `json:"isTransferWithStorage,omitempty"`
|
||||
IsTransferWithNotifiers bool `json:"isTransferWithNotifiers,omitempty"`
|
||||
TargetNotifierIDs []uuid.UUID `json:"targetNotifierIds,omitempty"`
|
||||
TargetNotifierIDs []uuid.UUID `json:"targetNotifierIds,omitzero"`
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/features/storages"
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
@@ -29,8 +28,8 @@ type BackupConfig struct {
|
||||
RetentionGfsMonths int `json:"retentionGfsMonths" gorm:"column:retention_gfs_months;type:int;not null;default:0"`
|
||||
RetentionGfsYears int `json:"retentionGfsYears" gorm:"column:retention_gfs_years;type:int;not null;default:0"`
|
||||
|
||||
BackupIntervalID uuid.UUID `json:"backupIntervalId" gorm:"column:backup_interval_id;type:uuid;not null"`
|
||||
BackupInterval *intervals.Interval `json:"backupInterval,omitempty" gorm:"foreignKey:BackupIntervalID"`
|
||||
BackupIntervalID uuid.UUID `json:"backupIntervalId" gorm:"column:backup_interval_id;type:uuid;not null"`
|
||||
BackupInterval *intervals.Interval `json:"backupInterval,omitzero" gorm:"foreignKey:BackupIntervalID"`
|
||||
|
||||
Storage *storages.Storage `json:"storage" gorm:"foreignKey:StorageID"`
|
||||
StorageID *uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;"`
|
||||
@@ -42,11 +41,6 @@ type BackupConfig struct {
|
||||
MaxFailedTriesCount int `json:"maxFailedTriesCount" gorm:"column:max_failed_tries_count;type:int;not null"`
|
||||
|
||||
Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
|
||||
|
||||
// MaxBackupSizeMB limits individual backup size. 0 = unlimited.
|
||||
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
|
||||
// MaxBackupsTotalSizeMB limits total size of all backups. 0 = unlimited.
|
||||
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
|
||||
}
|
||||
|
||||
func (h *BackupConfig) TableName() string {
|
||||
@@ -86,12 +80,12 @@ func (b *BackupConfig) AfterFind(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
|
||||
func (b *BackupConfig) Validate() error {
|
||||
if b.BackupIntervalID == uuid.Nil && b.BackupInterval == nil {
|
||||
return errors.New("backup interval is required")
|
||||
}
|
||||
|
||||
if err := b.validateRetentionPolicy(plan); err != nil {
|
||||
if err := b.validateRetentionPolicy(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -110,67 +104,38 @@ func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
|
||||
}
|
||||
}
|
||||
|
||||
if b.MaxBackupSizeMB < 0 {
|
||||
return errors.New("max backup size must be non-negative")
|
||||
}
|
||||
|
||||
if b.MaxBackupsTotalSizeMB < 0 {
|
||||
return errors.New("max backups total size must be non-negative")
|
||||
}
|
||||
|
||||
if plan.MaxBackupSizeMB > 0 {
|
||||
if b.MaxBackupSizeMB == 0 || b.MaxBackupSizeMB > plan.MaxBackupSizeMB {
|
||||
return errors.New("max backup size exceeds plan limit")
|
||||
}
|
||||
}
|
||||
|
||||
if plan.MaxBackupsTotalSizeMB > 0 {
|
||||
if b.MaxBackupsTotalSizeMB == 0 ||
|
||||
b.MaxBackupsTotalSizeMB > plan.MaxBackupsTotalSizeMB {
|
||||
return errors.New("max total backups size exceeds plan limit")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
|
||||
return &BackupConfig{
|
||||
DatabaseID: newDatabaseID,
|
||||
IsBackupsEnabled: b.IsBackupsEnabled,
|
||||
RetentionPolicyType: b.RetentionPolicyType,
|
||||
RetentionTimePeriod: b.RetentionTimePeriod,
|
||||
RetentionCount: b.RetentionCount,
|
||||
RetentionGfsHours: b.RetentionGfsHours,
|
||||
RetentionGfsDays: b.RetentionGfsDays,
|
||||
RetentionGfsWeeks: b.RetentionGfsWeeks,
|
||||
RetentionGfsMonths: b.RetentionGfsMonths,
|
||||
RetentionGfsYears: b.RetentionGfsYears,
|
||||
BackupIntervalID: uuid.Nil,
|
||||
BackupInterval: b.BackupInterval.Copy(),
|
||||
StorageID: b.StorageID,
|
||||
SendNotificationsOn: b.SendNotificationsOn,
|
||||
IsRetryIfFailed: b.IsRetryIfFailed,
|
||||
MaxFailedTriesCount: b.MaxFailedTriesCount,
|
||||
Encryption: b.Encryption,
|
||||
MaxBackupSizeMB: b.MaxBackupSizeMB,
|
||||
MaxBackupsTotalSizeMB: b.MaxBackupsTotalSizeMB,
|
||||
DatabaseID: newDatabaseID,
|
||||
IsBackupsEnabled: b.IsBackupsEnabled,
|
||||
RetentionPolicyType: b.RetentionPolicyType,
|
||||
RetentionTimePeriod: b.RetentionTimePeriod,
|
||||
RetentionCount: b.RetentionCount,
|
||||
RetentionGfsHours: b.RetentionGfsHours,
|
||||
RetentionGfsDays: b.RetentionGfsDays,
|
||||
RetentionGfsWeeks: b.RetentionGfsWeeks,
|
||||
RetentionGfsMonths: b.RetentionGfsMonths,
|
||||
RetentionGfsYears: b.RetentionGfsYears,
|
||||
BackupIntervalID: uuid.Nil,
|
||||
BackupInterval: b.BackupInterval.Copy(),
|
||||
StorageID: b.StorageID,
|
||||
SendNotificationsOn: b.SendNotificationsOn,
|
||||
IsRetryIfFailed: b.IsRetryIfFailed,
|
||||
MaxFailedTriesCount: b.MaxFailedTriesCount,
|
||||
Encryption: b.Encryption,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BackupConfig) validateRetentionPolicy(plan *plans.DatabasePlan) error {
|
||||
func (b *BackupConfig) validateRetentionPolicy() error {
|
||||
switch b.RetentionPolicyType {
|
||||
case RetentionPolicyTypeTimePeriod, "":
|
||||
if b.RetentionTimePeriod == "" {
|
||||
return errors.New("retention time period is required")
|
||||
}
|
||||
|
||||
if plan.MaxStoragePeriod != period.PeriodForever {
|
||||
if b.RetentionTimePeriod.CompareTo(plan.MaxStoragePeriod) > 0 {
|
||||
return errors.New("storage period exceeds plan limit")
|
||||
}
|
||||
}
|
||||
|
||||
case RetentionPolicyTypeCount:
|
||||
if b.RetentionCount <= 0 {
|
||||
return errors.New("retention count must be greater than 0")
|
||||
|
||||
@@ -6,248 +6,34 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
|
||||
func Test_Validate_WhenRetentionTimePeriodIsWeekAndPlanAllowsMonth_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodWeek
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodMonth
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRetentionTimePeriodIsYearAndPlanAllowsMonth_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodYear
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodMonth
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "storage period exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRetentionTimePeriodIsForeverAndPlanAllowsForever_ValidationPasses(
|
||||
t *testing.T,
|
||||
) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodForever
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodForever
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRetentionTimePeriodIsForeverAndPlanAllowsYear_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodForever
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodYear
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "storage period exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRetentionTimePeriodEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodMonth
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodMonth
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSize100MBAndPlanAllows500MB_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 100
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 500
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSize500MBAndPlanAllows100MB_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 500
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 100
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max backup size exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 0
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanHas500MBLimit_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 500
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max backup size exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 500
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 500
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSize1GBAndPlanAllows5GB_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 1000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSize5GBAndPlanAllows1GB_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 1000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max total backups size exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 0
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanHas1GBLimit_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 1000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max total backups size exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenAllLimitsAreUnlimitedInPlan_AnyConfigurationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodForever
|
||||
config.MaxBackupSizeMB = 0
|
||||
config.MaxBackupsTotalSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenMultipleLimitsExceeded_ValidationFailsWithFirstError(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodYear
|
||||
config.MaxBackupSizeMB = 500
|
||||
config.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodMonth
|
||||
plan.MaxBackupSizeMB = 100
|
||||
plan.MaxBackupsTotalSizeMB = 1000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.Error(t, err)
|
||||
assert.EqualError(t, err, "storage period exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenConfigHasInvalidIntervalButPlanIsValid_ValidationFailsOnInterval(
|
||||
t *testing.T,
|
||||
) {
|
||||
func Test_Validate_WhenIntervalIsMissing_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.BackupIntervalID = uuid.Nil
|
||||
config.BackupInterval = nil
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "backup interval is required")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenIntervalIsMissing_ValidationFailsRegardlessOfPlan(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.BackupIntervalID = uuid.Nil
|
||||
config.BackupInterval = nil
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "backup interval is required")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRetryEnabledButMaxTriesIsZero_ValidationFailsRegardlessOfPlan(t *testing.T) {
|
||||
func Test_Validate_WhenRetryEnabledButMaxTriesIsZero_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.IsRetryIfFailed = true
|
||||
config.MaxFailedTriesCount = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "max failed tries count must be greater than 0")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenEncryptionIsInvalid_ValidationFailsRegardlessOfPlan(t *testing.T) {
|
||||
func Test_Validate_WhenEncryptionIsInvalid_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.Encryption = "INVALID"
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "encryption must be NONE or ENCRYPTED")
|
||||
}
|
||||
|
||||
@@ -255,125 +41,16 @@ func Test_Validate_WhenRetentionTimePeriodIsEmpty_ValidationFails(t *testing.T)
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = ""
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "retention time period is required")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenMaxBackupSizeIsNegative_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = -100
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max backup size must be non-negative")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenMaxTotalSizeIsNegative_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = -1000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max backups total size must be non-negative")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenPlanLimitsAreAtBoundary_ValidationWorks(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
configPeriod period.TimePeriod
|
||||
planPeriod period.TimePeriod
|
||||
configSize int64
|
||||
planSize int64
|
||||
configTotal int64
|
||||
planTotal int64
|
||||
shouldSucceed bool
|
||||
}{
|
||||
{
|
||||
name: "all values just under limit",
|
||||
configPeriod: period.PeriodWeek,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 99,
|
||||
planSize: 100,
|
||||
configTotal: 999,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: true,
|
||||
},
|
||||
{
|
||||
name: "all values equal to limit",
|
||||
configPeriod: period.PeriodMonth,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 100,
|
||||
planSize: 100,
|
||||
configTotal: 1000,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: true,
|
||||
},
|
||||
{
|
||||
name: "period just over limit",
|
||||
configPeriod: period.Period3Month,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 100,
|
||||
planSize: 100,
|
||||
configTotal: 1000,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: false,
|
||||
},
|
||||
{
|
||||
name: "size just over limit",
|
||||
configPeriod: period.PeriodMonth,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 101,
|
||||
planSize: 100,
|
||||
configTotal: 1000,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: false,
|
||||
},
|
||||
{
|
||||
name: "total size just over limit",
|
||||
configPeriod: period.PeriodMonth,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 100,
|
||||
planSize: 100,
|
||||
configTotal: 1001,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = tt.configPeriod
|
||||
config.MaxBackupSizeMB = tt.configSize
|
||||
config.MaxBackupsTotalSizeMB = tt.configTotal
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = tt.planPeriod
|
||||
plan.MaxBackupSizeMB = tt.planSize
|
||||
plan.MaxBackupsTotalSizeMB = tt.planTotal
|
||||
|
||||
err := config.Validate(plan)
|
||||
if tt.shouldSucceed {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Validate_WhenPolicyTypeIsCount_RequiresPositiveCount(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionPolicyType = RetentionPolicyTypeCount
|
||||
config.RetentionCount = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "retention count must be greater than 0")
|
||||
}
|
||||
|
||||
@@ -382,9 +59,7 @@ func Test_Validate_WhenPolicyTypeIsCount_WithPositiveCount_ValidationPasses(t *t
|
||||
config.RetentionPolicyType = RetentionPolicyTypeCount
|
||||
config.RetentionCount = 10
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -396,9 +71,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_RequiresAtLeastOneField(t *testing.T) {
|
||||
config.RetentionGfsMonths = 0
|
||||
config.RetentionGfsYears = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "at least one GFS retention field must be greater than 0")
|
||||
}
|
||||
|
||||
@@ -407,9 +80,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_WithOnlyHours_ValidationPasses(t *testing
|
||||
config.RetentionPolicyType = RetentionPolicyTypeGFS
|
||||
config.RetentionGfsHours = 24
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -418,9 +89,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_WithOnlyDays_ValidationPasses(t *testing.
|
||||
config.RetentionPolicyType = RetentionPolicyTypeGFS
|
||||
config.RetentionGfsDays = 7
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -433,9 +102,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_WithAllFields_ValidationPasses(t *testing
|
||||
config.RetentionGfsMonths = 12
|
||||
config.RetentionGfsYears = 3
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -443,35 +110,59 @@ func Test_Validate_WhenPolicyTypeIsInvalid_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionPolicyType = "INVALID"
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "invalid retention policy type")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenCloudAndEncryptionIsNotEncrypted_ValidationFails(t *testing.T) {
|
||||
enableCloud(t)
|
||||
|
||||
backupConfig := createValidBackupConfig()
|
||||
backupConfig.Encryption = BackupEncryptionNone
|
||||
|
||||
err := backupConfig.Validate()
|
||||
assert.EqualError(t, err, "encryption is mandatory for cloud storage")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenCloudAndEncryptionIsEncrypted_ValidationPasses(t *testing.T) {
|
||||
enableCloud(t)
|
||||
|
||||
backupConfig := createValidBackupConfig()
|
||||
backupConfig.Encryption = BackupEncryptionEncrypted
|
||||
|
||||
err := backupConfig.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenNotCloudAndEncryptionIsNotEncrypted_ValidationPasses(t *testing.T) {
|
||||
backupConfig := createValidBackupConfig()
|
||||
backupConfig.Encryption = BackupEncryptionNone
|
||||
|
||||
err := backupConfig.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func enableCloud(t *testing.T) {
|
||||
t.Helper()
|
||||
config.GetEnv().IsCloud = true
|
||||
t.Cleanup(func() {
|
||||
config.GetEnv().IsCloud = false
|
||||
})
|
||||
}
|
||||
|
||||
func createValidBackupConfig() *BackupConfig {
|
||||
intervalID := uuid.New()
|
||||
return &BackupConfig{
|
||||
DatabaseID: uuid.New(),
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodMonth,
|
||||
BackupIntervalID: intervalID,
|
||||
BackupInterval: &intervals.Interval{ID: intervalID},
|
||||
SendNotificationsOn: []BackupNotificationType{},
|
||||
IsRetryIfFailed: false,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 100,
|
||||
MaxBackupsTotalSizeMB: 1000,
|
||||
}
|
||||
}
|
||||
|
||||
func createUnlimitedPlan() *plans.DatabasePlan {
|
||||
return &plans.DatabasePlan{
|
||||
DatabaseID: uuid.New(),
|
||||
MaxBackupSizeMB: 0,
|
||||
MaxBackupsTotalSizeMB: 0,
|
||||
MaxStoragePeriod: period.PeriodForever,
|
||||
return &BackupConfig{
|
||||
DatabaseID: uuid.New(),
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodMonth,
|
||||
BackupIntervalID: intervalID,
|
||||
BackupInterval: &intervals.Interval{ID: intervalID},
|
||||
SendNotificationsOn: []BackupNotificationType{},
|
||||
IsRetryIfFailed: false,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,6 +62,14 @@ func (r *BackupConfigRepository) FindByDatabaseID(databaseID uuid.UUID) (*Backup
|
||||
GetDb().
|
||||
Preload("BackupInterval").
|
||||
Preload("Storage").
|
||||
Preload("Storage.LocalStorage").
|
||||
Preload("Storage.S3Storage").
|
||||
Preload("Storage.GoogleDriveStorage").
|
||||
Preload("Storage.NASStorage").
|
||||
Preload("Storage.AzureBlobStorage").
|
||||
Preload("Storage.FTPStorage").
|
||||
Preload("Storage.SFTPStorage").
|
||||
Preload("Storage.RcloneStorage").
|
||||
Where("database_id = ?", databaseID).
|
||||
First(&backupConfig).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
@@ -81,6 +89,14 @@ func (r *BackupConfigRepository) GetWithEnabledBackups() ([]*BackupConfig, error
|
||||
GetDb().
|
||||
Preload("BackupInterval").
|
||||
Preload("Storage").
|
||||
Preload("Storage.LocalStorage").
|
||||
Preload("Storage.S3Storage").
|
||||
Preload("Storage.GoogleDriveStorage").
|
||||
Preload("Storage.NASStorage").
|
||||
Preload("Storage.AzureBlobStorage").
|
||||
Preload("Storage.FTPStorage").
|
||||
Preload("Storage.SFTPStorage").
|
||||
Preload("Storage.RcloneStorage").
|
||||
Where("is_backups_enabled = ?", true).
|
||||
Find(&backupConfigs).Error; err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -8,10 +8,10 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_models "databasus-backend/internal/features/users/models"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
|
||||
type BackupConfigService struct {
|
||||
@@ -20,7 +20,6 @@ type BackupConfigService struct {
|
||||
storageService *storages.StorageService
|
||||
notifierService *notifiers.NotifierService
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
databasePlanService *plans.DatabasePlanService
|
||||
|
||||
dbStorageChangeListener BackupConfigStorageChangeListener
|
||||
}
|
||||
@@ -46,12 +45,7 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
|
||||
user *users_models.User,
|
||||
backupConfig *BackupConfig,
|
||||
) (*BackupConfig, error) {
|
||||
plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := backupConfig.Validate(plan); err != nil {
|
||||
if err := backupConfig.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -88,12 +82,7 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
|
||||
func (s *BackupConfigService) SaveBackupConfig(
|
||||
backupConfig *BackupConfig,
|
||||
) (*BackupConfig, error) {
|
||||
plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := backupConfig.Validate(plan); err != nil {
|
||||
if err := backupConfig.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -131,18 +120,6 @@ func (s *BackupConfigService) GetBackupConfigByDbIdWithAuth(
|
||||
return s.GetBackupConfigByDbId(databaseID)
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) GetDatabasePlan(
|
||||
user *users_models.User,
|
||||
databaseID uuid.UUID,
|
||||
) (*plans.DatabasePlan, error) {
|
||||
_, err := s.databaseService.GetDatabase(user, databaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.databasePlanService.GetDatabasePlan(databaseID)
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) GetBackupConfigByDbId(
|
||||
databaseID uuid.UUID,
|
||||
) (*BackupConfig, error) {
|
||||
@@ -322,20 +299,13 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace(
|
||||
func (s *BackupConfigService) initializeDefaultConfig(
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
plan, err := s.databasePlanService.GetDatabasePlan(databaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
timeOfDay := "04:00"
|
||||
|
||||
_, err = s.backupConfigRepository.Save(&BackupConfig{
|
||||
DatabaseID: databaseID,
|
||||
IsBackupsEnabled: false,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: plan.MaxStoragePeriod,
|
||||
MaxBackupSizeMB: plan.MaxBackupSizeMB,
|
||||
MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB,
|
||||
_, err := s.backupConfigRepository.Save(&BackupConfig{
|
||||
DatabaseID: databaseID,
|
||||
IsBackupsEnabled: false,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.Period3Month,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
|
||||
305
backend/internal/features/billing/controller.go
Normal file
305
backend/internal/features/billing/controller.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
type BillingController struct {
|
||||
billingService *BillingService
|
||||
}
|
||||
|
||||
func (c *BillingController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
billing := router.Group("/billing")
|
||||
|
||||
billing.POST("/subscription", c.CreateSubscription)
|
||||
billing.POST("/subscription/change-storage", c.ChangeSubscriptionStorage)
|
||||
billing.POST("/subscription/portal/:subscription_id", c.GetPortalSession)
|
||||
billing.GET("/subscription/events/:subscription_id", c.GetSubscriptionEvents)
|
||||
billing.GET("/subscription/invoices/:subscription_id", c.GetInvoices)
|
||||
billing.GET("/subscription/:database_id", c.GetSubscription)
|
||||
}
|
||||
|
||||
// CreateSubscription
|
||||
// @Summary Create a new subscription
|
||||
// @Description Create a billing subscription for the specified database with the given storage
|
||||
// @Tags billing
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body CreateSubscriptionRequest true "Subscription creation data"
|
||||
// @Success 200 {object} CreateSubscriptionResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /billing/subscription [post]
|
||||
func (c *BillingController) CreateSubscription(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request CreateSubscriptionRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(400, gin.H{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
log := logger.GetLogger().With(
|
||||
"request_id", uuid.New(),
|
||||
"database_id", request.DatabaseID,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
transactionID, err := c.billingService.CreateSubscription(
|
||||
log,
|
||||
user,
|
||||
request.DatabaseID,
|
||||
request.StorageGB,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error("Failed to create subscription", "error", err)
|
||||
ctx.JSON(500, gin.H{"error": "Failed to create subscription"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(200, CreateSubscriptionResponse{PaddleTransactionID: transactionID})
|
||||
}
|
||||
|
||||
// ChangeSubscriptionStorage
|
||||
// @Summary Change subscription storage
|
||||
// @Description Update the storage allocation for an existing subscription
|
||||
// @Tags billing
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body ChangeStorageRequest true "New storage configuration"
|
||||
// @Success 200 {object} ChangeStorageResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /billing/subscription/change-storage [post]
|
||||
func (c *BillingController) ChangeSubscriptionStorage(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request ChangeStorageRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(400, gin.H{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
log := logger.GetLogger().With(
|
||||
"request_id", uuid.New(),
|
||||
"database_id", request.DatabaseID,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
result, err := c.billingService.ChangeSubscriptionStorage(log, user, request.DatabaseID, request.StorageGB)
|
||||
if err != nil {
|
||||
log.Error("Failed to change subscription storage", "error", err)
|
||||
ctx.JSON(500, gin.H{"error": "Failed to change subscription storage"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(200, ChangeStorageResponse{
|
||||
ApplyMode: result.ApplyMode,
|
||||
CurrentGB: result.CurrentGB,
|
||||
PendingGB: result.PendingGB,
|
||||
})
|
||||
}
|
||||
|
||||
// GetPortalSession
|
||||
// @Summary Get billing portal session
|
||||
// @Description Generate a portal session URL for managing the subscription
|
||||
// @Tags billing
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param subscription_id path string true "Subscription ID"
|
||||
// @Success 200 {object} GetPortalSessionResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /billing/subscription/portal/{subscription_id} [post]
|
||||
func (c *BillingController) GetPortalSession(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
subscriptionID := ctx.Param("subscription_id")
|
||||
if subscriptionID == "" {
|
||||
ctx.JSON(400, gin.H{"error": "Subscription ID is required"})
|
||||
return
|
||||
}
|
||||
|
||||
log := logger.GetLogger().With(
|
||||
"request_id", uuid.New(),
|
||||
"subscription_id", subscriptionID,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
url, err := c.billingService.GetPortalURL(log, user, uuid.MustParse(subscriptionID))
|
||||
if err != nil {
|
||||
log.Error("Failed to get portal session", "error", err)
|
||||
ctx.JSON(500, gin.H{"error": "Failed to get portal session"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(200, GetPortalSessionResponse{PortalURL: url})
|
||||
}
|
||||
|
||||
// GetSubscriptionEvents
|
||||
// @Summary Get subscription events
|
||||
// @Description Retrieve the event history for a subscription
|
||||
// @Tags billing
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param subscription_id path string true "Subscription ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Success 200 {object} GetSubscriptionEventsResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /billing/subscription/events/{subscription_id} [get]
|
||||
func (c *BillingController) GetSubscriptionEvents(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
subscriptionID, err := uuid.Parse(ctx.Param("subscription_id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid subscription ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var request PaginatedRequest
|
||||
if err := ctx.ShouldBindQuery(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
log := logger.GetLogger().With(
|
||||
"request_id", uuid.New(),
|
||||
"subscription_id", subscriptionID,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
response, err := c.billingService.GetSubscriptionEvents(log, user, subscriptionID, request.Limit, request.Offset)
|
||||
if err != nil {
|
||||
log.Error("Failed to get subscription events", "error", err)
|
||||
ctx.JSON(500, gin.H{"error": "Failed to get subscription events"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(200, response)
|
||||
}
|
||||
|
||||
// GetInvoices
|
||||
// @Summary Get subscription invoices
|
||||
// @Description Retrieve all invoices for a subscription
|
||||
// @Tags billing
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param subscription_id path string true "Subscription ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Success 200 {object} GetInvoicesResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /billing/subscription/invoices/{subscription_id} [get]
|
||||
func (c *BillingController) GetInvoices(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
subscriptionID, err := uuid.Parse(ctx.Param("subscription_id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid subscription ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var request PaginatedRequest
|
||||
if err := ctx.ShouldBindQuery(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
log := logger.GetLogger().With(
|
||||
"request_id", uuid.New(),
|
||||
"subscription_id", subscriptionID,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
response, err := c.billingService.GetSubscriptionInvoices(log, user, subscriptionID, request.Limit, request.Offset)
|
||||
if err != nil {
|
||||
log.Error("Failed to get invoices", "error", err)
|
||||
ctx.JSON(500, gin.H{"error": "Failed to get invoices"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(200, response)
|
||||
}
|
||||
|
||||
// GetSubscription
|
||||
// @Summary Get subscription by database
|
||||
// @Description Retrieve the subscription associated with a specific database
|
||||
// @Tags billing
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param database_id path string true "Database ID"
|
||||
// @Success 200 {object} billing_models.Subscription
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /billing/subscription/{database_id} [get]
|
||||
func (c *BillingController) GetSubscription(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
databaseID, err := uuid.Parse(ctx.Param("database_id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid database ID"})
|
||||
return
|
||||
}
|
||||
|
||||
log := logger.GetLogger().With(
|
||||
"request_id", uuid.New(),
|
||||
"database_id", databaseID,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
subscription, err := c.billingService.GetSubscriptionByDatabaseID(log, user, databaseID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSubscriptionNotFound) {
|
||||
ctx.JSON(http.StatusNotFound, gin.H{"error": "Subscription not found"})
|
||||
return
|
||||
}
|
||||
|
||||
log.Error("failed to get subscription", "error", err)
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get subscription"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(200, subscription)
|
||||
}
|
||||
1450
backend/internal/features/billing/controller_test.go
Normal file
1450
backend/internal/features/billing/controller_test.go
Normal file
File diff suppressed because it is too large
Load Diff
35
backend/internal/features/billing/di.go
Normal file
35
backend/internal/features/billing/di.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
billing_repositories "databasus-backend/internal/features/billing/repositories"
|
||||
"databasus-backend/internal/features/databases"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
)
|
||||
|
||||
var (
|
||||
billingService = &BillingService{
|
||||
&billing_repositories.SubscriptionRepository{},
|
||||
&billing_repositories.SubscriptionEventRepository{},
|
||||
&billing_repositories.InvoiceRepository{},
|
||||
nil, // billing provider will be set later to avoid circular dependency
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
*databases.GetDatabaseService(),
|
||||
atomic.Bool{},
|
||||
}
|
||||
billingController = &BillingController{billingService}
|
||||
)
|
||||
|
||||
func GetBillingService() *BillingService {
|
||||
return billingService
|
||||
}
|
||||
|
||||
func GetBillingController() *BillingController {
|
||||
return billingController
|
||||
}
|
||||
|
||||
var SetupDependencies = sync.OnceFunc(func() {
|
||||
databases.GetDatabaseService().AddDbCreationListener(billingService)
|
||||
})
|
||||
67
backend/internal/features/billing/dto.go
Normal file
67
backend/internal/features/billing/dto.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
)
|
||||
|
||||
type CreateSubscriptionRequest struct {
|
||||
DatabaseID uuid.UUID `json:"databaseId" validate:"required"`
|
||||
StorageGB int `json:"storageGb" validate:"required,min=1"`
|
||||
}
|
||||
|
||||
type CreateSubscriptionResponse struct {
|
||||
PaddleTransactionID string `json:"paddleTransactionId"`
|
||||
}
|
||||
|
||||
type ChangeStorageApplyMode string
|
||||
|
||||
const (
|
||||
ChangeStorageApplyImmediate ChangeStorageApplyMode = "immediate"
|
||||
ChangeStorageApplyNextCycle ChangeStorageApplyMode = "next_cycle"
|
||||
)
|
||||
|
||||
type ChangeStorageRequest struct {
|
||||
DatabaseID uuid.UUID `json:"databaseId" validate:"required"`
|
||||
StorageGB int `json:"storageGb" validate:"required,min=1"`
|
||||
}
|
||||
|
||||
type ChangeStorageResponse struct {
|
||||
ApplyMode ChangeStorageApplyMode `json:"applyMode"`
|
||||
CurrentGB int `json:"currentGb"`
|
||||
PendingGB *int `json:"pendingGb,omitempty"`
|
||||
}
|
||||
|
||||
type PortalResponse struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type ChangeStorageResult struct {
|
||||
ApplyMode ChangeStorageApplyMode
|
||||
CurrentGB int
|
||||
PendingGB *int
|
||||
}
|
||||
|
||||
type GetPortalSessionResponse struct {
|
||||
PortalURL string `json:"url"`
|
||||
}
|
||||
|
||||
type PaginatedRequest struct {
|
||||
Limit int `form:"limit" json:"limit"`
|
||||
Offset int `form:"offset" json:"offset"`
|
||||
}
|
||||
|
||||
type GetSubscriptionEventsResponse struct {
|
||||
Events []*billing_models.SubscriptionEvent `json:"events"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
type GetInvoicesResponse struct {
|
||||
Invoices []*billing_models.Invoice `json:"invoices"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
15
backend/internal/features/billing/errors.go
Normal file
15
backend/internal/features/billing/errors.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package billing
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrInvalidStorage = errors.New("storage must be between 20 and 10000 GB")
|
||||
ErrAlreadySubscribed = errors.New("database already has an active subscription")
|
||||
ErrExceedsUsage = errors.New("cannot downgrade below current storage usage")
|
||||
ErrNoChange = errors.New("requested storage is the same as current")
|
||||
ErrDuplicate = errors.New("duplicate event already processed")
|
||||
ErrProviderUnavailable = errors.New("payment provider unavailable")
|
||||
ErrNoActiveSubscription = errors.New("no active subscription for this database")
|
||||
ErrAccessDenied = errors.New("user does not have access to this database")
|
||||
ErrSubscriptionNotFound = errors.New("subscription not found")
|
||||
)
|
||||
24
backend/internal/features/billing/models/invoice.go
Normal file
24
backend/internal/features/billing/models/invoice.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package billing_models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Invoice struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
SubscriptionID uuid.UUID `json:"subscriptionId" gorm:"column:subscription_id;type:uuid;not null"`
|
||||
ProviderInvoiceID string `json:"providerInvoiceId" gorm:"column:provider_invoice_id;type:text;not null"`
|
||||
AmountCents int64 `json:"amountCents" gorm:"column:amount_cents;type:bigint;not null"`
|
||||
StorageGB int `json:"storageGb" gorm:"column:storage_gb;type:int;not null"`
|
||||
PeriodStart time.Time `json:"periodStart" gorm:"column:period_start;type:timestamptz;not null"`
|
||||
PeriodEnd time.Time `json:"periodEnd" gorm:"column:period_end;type:timestamptz;not null"`
|
||||
Status InvoiceStatus `json:"status" gorm:"column:status;type:text;not null"`
|
||||
PaidAt *time.Time `json:"paidAt,omitzero" gorm:"column:paid_at;type:timestamptz"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"`
|
||||
}
|
||||
|
||||
func (Invoice) TableName() string {
|
||||
return "invoices"
|
||||
}
|
||||
11
backend/internal/features/billing/models/invoice_status.go
Normal file
11
backend/internal/features/billing/models/invoice_status.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package billing_models
|
||||
|
||||
type InvoiceStatus string
|
||||
|
||||
const (
|
||||
InvoiceStatusPending InvoiceStatus = "pending"
|
||||
InvoiceStatusPaid InvoiceStatus = "paid"
|
||||
InvoiceStatusFailed InvoiceStatus = "failed"
|
||||
InvoiceStatusRefunded InvoiceStatus = "refunded"
|
||||
InvoiceStatusDisputed InvoiceStatus = "disputed"
|
||||
)
|
||||
72
backend/internal/features/billing/models/subscription.go
Normal file
72
backend/internal/features/billing/models/subscription.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package billing_models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
)
|
||||
|
||||
type Subscription struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
|
||||
Status SubscriptionStatus `json:"status" gorm:"column:status;type:text;not null"`
|
||||
|
||||
StorageGB int `json:"storageGb" gorm:"column:storage_gb;type:int;not null"`
|
||||
PendingStorageGB *int `json:"pendingStorageGb,omitempty" gorm:"column:pending_storage_gb;type:int"`
|
||||
|
||||
CurrentPeriodStart time.Time `json:"currentPeriodStart" gorm:"column:current_period_start;type:timestamptz;not null"`
|
||||
CurrentPeriodEnd time.Time `json:"currentPeriodEnd" gorm:"column:current_period_end;type:timestamptz;not null"`
|
||||
CanceledAt *time.Time `json:"canceledAt,omitzero" gorm:"column:canceled_at;type:timestamptz"`
|
||||
|
||||
DataRetentionGracePeriodUntil *time.Time `json:"dataRetentionGracePeriodUntil,omitzero" gorm:"column:data_retention_grace_period_until;type:timestamptz"`
|
||||
|
||||
ProviderName *string `json:"providerName,omitempty" gorm:"column:provider_name;type:text"`
|
||||
ProviderSubID *string `json:"providerSubId,omitempty" gorm:"column:provider_sub_id;type:text"`
|
||||
ProviderCustomerID *string `json:"providerCustomerId,omitempty" gorm:"column:provider_customer_id;type:text"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"`
|
||||
UpdatedAt time.Time `json:"updatedAt" gorm:"column:updated_at;type:timestamptz;not null"`
|
||||
}
|
||||
|
||||
func (Subscription) TableName() string {
|
||||
return "subscriptions"
|
||||
}
|
||||
|
||||
func (s *Subscription) PriceCents() int64 {
|
||||
return int64(s.StorageGB) * config.GetEnv().PricePerGBCents
|
||||
}
|
||||
|
||||
// CanCreateNewBackups - whether it is allowed to create new backups
|
||||
// by scheduler or for user manually. Clarification: in grace period
|
||||
// user can download, delete and restore backups, but cannot create new ones
|
||||
func (s *Subscription) CanCreateNewBackups() bool {
|
||||
switch s.Status {
|
||||
case StatusActive, StatusPastDue:
|
||||
return true
|
||||
case StatusTrial, StatusCanceled:
|
||||
return time.Now().Before(s.CurrentPeriodEnd)
|
||||
case StatusExpired:
|
||||
return false
|
||||
default:
|
||||
panic("unknown subscription status")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Subscription) GetBackupsStorageGB() int {
|
||||
switch s.Status {
|
||||
case StatusActive, StatusPastDue, StatusCanceled:
|
||||
return s.StorageGB
|
||||
case StatusTrial:
|
||||
if time.Now().Before(s.CurrentPeriodEnd) {
|
||||
return s.StorageGB
|
||||
}
|
||||
|
||||
return 0
|
||||
case StatusExpired:
|
||||
return 0
|
||||
default:
|
||||
panic("unknown subscription status")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package billing_models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type SubscriptionEvent struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
SubscriptionID uuid.UUID `json:"subscriptionId" gorm:"column:subscription_id;type:uuid;not null"`
|
||||
ProviderEventID *string `json:"providerEventId,omitempty" gorm:"column:provider_event_id;type:text"`
|
||||
Type SubscriptionEventType `json:"type" gorm:"column:type;type:text;not null"`
|
||||
|
||||
OldStorageGB *int `json:"oldStorageGb,omitempty" gorm:"column:old_storage_gb;type:int"`
|
||||
NewStorageGB *int `json:"newStorageGb,omitempty" gorm:"column:new_storage_gb;type:int"`
|
||||
OldStatus *SubscriptionStatus `json:"oldStatus,omitempty" gorm:"column:old_status;type:text"`
|
||||
NewStatus *SubscriptionStatus `json:"newStatus,omitempty" gorm:"column:new_status;type:text"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"`
|
||||
}
|
||||
|
||||
func (SubscriptionEvent) TableName() string {
|
||||
return "subscription_events"
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package billing_models
|
||||
|
||||
type SubscriptionEventType string
|
||||
|
||||
const (
|
||||
EventCreated SubscriptionEventType = "subscription.created"
|
||||
EventUpgraded SubscriptionEventType = "subscription.upgraded"
|
||||
EventDowngraded SubscriptionEventType = "subscription.downgraded"
|
||||
EventNewBillingCycleStarted SubscriptionEventType = "subscription.new_billing_cycle_started"
|
||||
EventCanceled SubscriptionEventType = "subscription.canceled"
|
||||
EventReactivated SubscriptionEventType = "subscription.reactivated"
|
||||
EventExpired SubscriptionEventType = "subscription.expired"
|
||||
EventPastDue SubscriptionEventType = "subscription.past_due"
|
||||
EventRecoveredFromPastDue SubscriptionEventType = "subscription.recovered_from_past_due"
|
||||
EventRefund SubscriptionEventType = "payment.refund"
|
||||
EventDispute SubscriptionEventType = "payment.dispute"
|
||||
)
|
||||
@@ -0,0 +1,11 @@
|
||||
package billing_models
|
||||
|
||||
type SubscriptionStatus string
|
||||
|
||||
const (
|
||||
StatusTrial SubscriptionStatus = "trial" // trial period (~24h after DB creation)
|
||||
StatusActive SubscriptionStatus = "active" // paid, everything works
|
||||
StatusPastDue SubscriptionStatus = "past_due" // payment failed, trying to charge again, but everything still works
|
||||
StatusCanceled SubscriptionStatus = "canceled" // subscription canceled by user or after past_due (grace period is active)
|
||||
StatusExpired SubscriptionStatus = "expired" // grace period ended, data marked for deletion, can come from canceled and trial
|
||||
)
|
||||
22
backend/internal/features/billing/models/webhook_event.go
Normal file
22
backend/internal/features/billing/models/webhook_event.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package billing_models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type WebhookEvent struct {
|
||||
RequestID uuid.UUID
|
||||
ProviderEventID string
|
||||
DatabaseID *uuid.UUID
|
||||
Type WebhookEventType
|
||||
ProviderSubscriptionID string
|
||||
ProviderCustomerID string
|
||||
ProviderInvoiceID string
|
||||
QuantityGB int
|
||||
Status SubscriptionStatus
|
||||
PeriodStart *time.Time
|
||||
PeriodEnd *time.Time
|
||||
AmountCents int64
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package billing_models
|
||||
|
||||
type WebhookEventType string
|
||||
|
||||
const (
|
||||
WHEventSubscriptionCreated WebhookEventType = "subscription.created"
|
||||
WHEventSubscriptionUpdated WebhookEventType = "subscription.updated"
|
||||
WHEventSubscriptionCanceled WebhookEventType = "subscription.canceled"
|
||||
WHEventSubscriptionPastDue WebhookEventType = "subscription.past_due"
|
||||
WHEventSubscriptionReactivated WebhookEventType = "subscription.reactivated"
|
||||
WHEventPaymentSucceeded WebhookEventType = "payment.succeeded"
|
||||
WHEventSubscriptionDisputeCreated WebhookEventType = "dispute.created"
|
||||
)
|
||||
5
backend/internal/features/billing/paddle/README.md
Normal file
5
backend/internal/features/billing/paddle/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
**Paddle hints:**
|
||||
|
||||
- **max_quantity on price:** Paddle limits `quantity` on a price to 100 by default. You need to explicitly set the range (`quantity: {minimum: 20, maximum: 10000}`) when creating a price via API or dashboard. Otherwise requests with quantity > 100 will return an error.
|
||||
- **Full items list on update:** Unlike Stripe, Paddle requires sending **all** subscription items in `PATCH /subscriptions/{id}`, not just the changed ones. `proration_billing_mode` is also required. Without this you can accidentally remove a line item or get a 400.
|
||||
- **Webhook events mapping:** Paddle uses `transaction.completed` instead of `payment.succeeded`, `transaction.payment_failed` instead of `payment.failed`, `adjustment.created` instead of `dispute.created`.
|
||||
83
backend/internal/features/billing/paddle/controller.go
Normal file
83
backend/internal/features/billing/paddle/controller.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package billing_paddle
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
billing_webhooks "databasus-backend/internal/features/billing/webhooks"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
type PaddleBillingController struct {
|
||||
paddleBillingService *PaddleBillingService
|
||||
}
|
||||
|
||||
func (c *PaddleBillingController) RegisterPublicRoutes(router *gin.RouterGroup) {
|
||||
router.POST("/billing/paddle/webhook", c.HandlePaddleWebhook)
|
||||
}
|
||||
|
||||
// HandlePaddleWebhook
|
||||
// @Summary Handle Paddle webhook
|
||||
// @Description Process incoming webhook events from Paddle payment provider
|
||||
// @Tags billing
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Paddle-Signature header string true "Paddle webhook signature"
|
||||
// @Success 200
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500
|
||||
// @Router /billing/paddle/webhook [post]
|
||||
func (c *PaddleBillingController) HandlePaddleWebhook(ctx *gin.Context) {
|
||||
requestID := uuid.New()
|
||||
log := logger.GetLogger().With("request_id", requestID)
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(ctx.Request.Body, 1<<20))
|
||||
if err != nil {
|
||||
log.Error("failed to read webhook request body", "error", err)
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Failed to read request body"})
|
||||
return
|
||||
}
|
||||
|
||||
headers := make(map[string]string)
|
||||
for k := range ctx.Request.Header {
|
||||
headers[k] = ctx.GetHeader(k)
|
||||
}
|
||||
|
||||
if err := c.paddleBillingService.VerifyWebhookSignature(body, headers); err != nil {
|
||||
log.Warn("paddle webhook signature verification failed", "error", err)
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid webhook signature"})
|
||||
return
|
||||
}
|
||||
|
||||
var webhookDTO PaddleWebhookDTO
|
||||
if err := json.Unmarshal(body, &webhookDTO); err != nil {
|
||||
log.Error("failed to unmarshal webhook payload", "error", err)
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid webhook payload"})
|
||||
return
|
||||
}
|
||||
|
||||
log = log.With(
|
||||
"provider_event_id", webhookDTO.EventID,
|
||||
"event_type", webhookDTO.EventType,
|
||||
)
|
||||
|
||||
if err := c.paddleBillingService.ProcessWebhookEvent(log, requestID, webhookDTO, body); err != nil {
|
||||
if errors.Is(err, billing_webhooks.ErrDuplicateWebhook) {
|
||||
log.Info("duplicate webhook event, returning 200 to not force retry")
|
||||
ctx.Status(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
log.Error("Failed to process paddle webhook", "error", err)
|
||||
ctx.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusOK)
|
||||
}
|
||||
1056
backend/internal/features/billing/paddle/controller_test.go
Normal file
1056
backend/internal/features/billing/paddle/controller_test.go
Normal file
File diff suppressed because it is too large
Load Diff
72
backend/internal/features/billing/paddle/di.go
Normal file
72
backend/internal/features/billing/paddle/di.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package billing_paddle
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/PaddleHQ/paddle-go-sdk"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/billing"
|
||||
billing_webhooks "databasus-backend/internal/features/billing/webhooks"
|
||||
)
|
||||
|
||||
var (
|
||||
paddleBillingService *PaddleBillingService
|
||||
paddleBillingController *PaddleBillingController
|
||||
)
|
||||
|
||||
var initPaddle = sync.OnceFunc(func() {
|
||||
if config.GetEnv().IsPaddleSandbox {
|
||||
paddleClient, err := paddle.NewSandbox(config.GetEnv().PaddleApiKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
paddleBillingService = &PaddleBillingService{
|
||||
paddleClient,
|
||||
paddle.NewWebhookVerifier(config.GetEnv().PaddleWebhookSecret),
|
||||
config.GetEnv().PaddlePriceID,
|
||||
billing_webhooks.WebhookRepository{},
|
||||
billing.GetBillingService(),
|
||||
}
|
||||
} else {
|
||||
paddleClient, err := paddle.New(config.GetEnv().PaddleApiKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
paddleBillingService = &PaddleBillingService{
|
||||
paddleClient,
|
||||
paddle.NewWebhookVerifier(config.GetEnv().PaddleWebhookSecret),
|
||||
config.GetEnv().PaddlePriceID,
|
||||
billing_webhooks.WebhookRepository{},
|
||||
billing.GetBillingService(),
|
||||
}
|
||||
}
|
||||
|
||||
paddleBillingController = &PaddleBillingController{paddleBillingService}
|
||||
})
|
||||
|
||||
func GetPaddleBillingService() *PaddleBillingService {
|
||||
if !config.GetEnv().IsCloud {
|
||||
return nil
|
||||
}
|
||||
|
||||
initPaddle()
|
||||
return paddleBillingService
|
||||
}
|
||||
|
||||
func GetPaddleBillingController() *PaddleBillingController {
|
||||
if !config.GetEnv().IsCloud {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure service + controller are initialized
|
||||
GetPaddleBillingService()
|
||||
|
||||
return paddleBillingController
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
billing.GetBillingService().SetBillingProvider(GetPaddleBillingService())
|
||||
}
|
||||
9
backend/internal/features/billing/paddle/dto.go
Normal file
9
backend/internal/features/billing/paddle/dto.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package billing_paddle
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type PaddleWebhookDTO struct {
|
||||
EventID string `json:"event_id"`
|
||||
EventType string `json:"event_type"`
|
||||
Data json.RawMessage
|
||||
}
|
||||
50
backend/internal/features/billing/paddle/dto_test.go
Normal file
50
backend/internal/features/billing/paddle/dto_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package billing_paddle
|
||||
|
||||
import "time"
|
||||
|
||||
type TestSubscriptionCreatedPayload struct {
|
||||
EventID string
|
||||
SubID string
|
||||
CustomerID string
|
||||
DatabaseID string
|
||||
QuantityGB int
|
||||
PeriodStart time.Time
|
||||
PeriodEnd time.Time
|
||||
}
|
||||
|
||||
type TestSubscriptionUpdatedPayload struct {
|
||||
EventID string
|
||||
SubID string
|
||||
CustomerID string
|
||||
QuantityGB int
|
||||
PeriodStart time.Time
|
||||
PeriodEnd time.Time
|
||||
HasScheduledChange bool
|
||||
ScheduledChangeAction string
|
||||
}
|
||||
|
||||
type TestSubscriptionCanceledPayload struct {
|
||||
EventID string
|
||||
SubID string
|
||||
CustomerID string
|
||||
}
|
||||
|
||||
type TestTransactionCompletedPayload struct {
|
||||
EventID string
|
||||
TxnID string
|
||||
SubID string
|
||||
CustomerID string
|
||||
TotalCents int64
|
||||
QuantityGB int
|
||||
PeriodStart time.Time
|
||||
PeriodEnd time.Time
|
||||
}
|
||||
|
||||
type TestSubscriptionPastDuePayload struct {
|
||||
EventID string
|
||||
SubID string
|
||||
CustomerID string
|
||||
QuantityGB int
|
||||
PeriodStart time.Time
|
||||
PeriodEnd time.Time
|
||||
}
|
||||
638
backend/internal/features/billing/paddle/service.go
Normal file
638
backend/internal/features/billing/paddle/service.go
Normal file
@@ -0,0 +1,638 @@
|
||||
package billing_paddle
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/PaddleHQ/paddle-go-sdk"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/features/billing"
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
billing_provider "databasus-backend/internal/features/billing/provider"
|
||||
billing_webhooks "databasus-backend/internal/features/billing/webhooks"
|
||||
)
|
||||
|
||||
type PaddleBillingService struct {
|
||||
client *paddle.SDK
|
||||
webhookVerified *paddle.WebhookVerifier
|
||||
priceID string
|
||||
webhookRepository billing_webhooks.WebhookRepository
|
||||
billingService *billing.BillingService
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) GetProviderName() billing_provider.ProviderName {
|
||||
return billing_provider.ProviderPaddle
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) CreateCheckoutSession(
|
||||
logger *slog.Logger,
|
||||
request billing_provider.CheckoutRequest,
|
||||
) (string, error) {
|
||||
logger = logger.With("database_id", request.DatabaseID)
|
||||
logger.Debug(fmt.Sprintf("paddle: creating checkout session for %d GB", request.StorageGB))
|
||||
|
||||
txRequest := &paddle.CreateTransactionRequest{
|
||||
Items: []paddle.CreateTransactionItems{
|
||||
*paddle.NewCreateTransactionItemsCatalogItem(&paddle.CatalogItem{
|
||||
PriceID: s.priceID,
|
||||
Quantity: request.StorageGB,
|
||||
}),
|
||||
},
|
||||
CustomData: paddle.CustomData{"database_id": request.DatabaseID.String()},
|
||||
Checkout: &paddle.TransactionCheckout{
|
||||
URL: &request.SuccessURL,
|
||||
},
|
||||
}
|
||||
|
||||
tx, err := s.client.CreateTransaction(context.Background(), txRequest)
|
||||
if err != nil {
|
||||
logger.Error("paddle: failed to create transaction", "error", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
return tx.ID, nil
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) UpgradeQuantityWithSurcharge(
|
||||
logger *slog.Logger,
|
||||
providerSubscriptionID string,
|
||||
quantityGB int,
|
||||
) error {
|
||||
logger = logger.With("provider_subscription_id", providerSubscriptionID)
|
||||
logger.Debug(fmt.Sprintf("paddle: applying upgrade: new storage %d GB", quantityGB))
|
||||
|
||||
// important: paddle requires to send all items
|
||||
// in the subscription when updating, not just the changed one
|
||||
subscription, err := s.client.GetSubscription(context.Background(), &paddle.GetSubscriptionRequest{
|
||||
SubscriptionID: providerSubscriptionID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("paddle: failed to get subscription", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
currentQuantity := subscription.Items[0].Quantity
|
||||
if currentQuantity == quantityGB {
|
||||
logger.Info("paddle: subscription already at requested quantity, skipping upgrade",
|
||||
"current_quantity_gb", currentQuantity,
|
||||
"requested_quantity_gb", quantityGB,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
priceID := subscription.Items[0].Price.ID
|
||||
|
||||
_, err = s.client.UpdateSubscription(context.Background(), &paddle.UpdateSubscriptionRequest{
|
||||
SubscriptionID: providerSubscriptionID,
|
||||
Items: paddle.NewPatchField([]paddle.SubscriptionUpdateCatalogItem{
|
||||
{
|
||||
PriceID: priceID,
|
||||
Quantity: quantityGB,
|
||||
},
|
||||
}),
|
||||
ProrationBillingMode: paddle.NewPatchField(paddle.ProrationBillingModeProratedImmediately),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("paddle: failed to update subscription", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Debug("paddle: successfully applied upgrade")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) ScheduleQuantityDowngradeFromNextBillingCycle(
|
||||
logger *slog.Logger,
|
||||
providerSubscriptionID string,
|
||||
quantityGB int,
|
||||
) error {
|
||||
logger = logger.With("provider_subscription_id", providerSubscriptionID)
|
||||
logger.Debug(fmt.Sprintf("paddle: scheduling downgrade from next billing cycle: new storage %d GB", quantityGB))
|
||||
|
||||
// important: paddle requires to send all items
|
||||
// in the subscription when updating, not just the changed one
|
||||
subscription, err := s.client.GetSubscription(context.Background(), &paddle.GetSubscriptionRequest{
|
||||
SubscriptionID: providerSubscriptionID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("paddle: failed to get subscription", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
currentQuantity := subscription.Items[0].Quantity
|
||||
if currentQuantity == quantityGB {
|
||||
logger.Info("paddle: subscription already at requested quantity, skipping downgrade",
|
||||
"current_quantity_gb", currentQuantity,
|
||||
"requested_quantity_gb", quantityGB,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
if subscription.ScheduledChange != nil {
|
||||
logger.Info("paddle: subscription already has a scheduled change, skipping downgrade")
|
||||
return nil
|
||||
}
|
||||
|
||||
priceID := subscription.Items[0].Price.ID
|
||||
|
||||
// apply downgrade from next billing cycle by setting the proration billing mode to "prorate on next billing period"
|
||||
_, err = s.client.UpdateSubscription(context.Background(), &paddle.UpdateSubscriptionRequest{
|
||||
SubscriptionID: providerSubscriptionID,
|
||||
Items: paddle.NewPatchField([]paddle.SubscriptionUpdateCatalogItem{
|
||||
{
|
||||
PriceID: priceID,
|
||||
Quantity: quantityGB,
|
||||
},
|
||||
}),
|
||||
ProrationBillingMode: paddle.NewPatchField(paddle.ProrationBillingModeFullNextBillingPeriod),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("paddle: failed to update subscription for downgrade", "error", err)
|
||||
return fmt.Errorf("failed to update subscription: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("paddle: successfully scheduled downgrade from next billing cycle")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) GetSubscription(
|
||||
logger *slog.Logger,
|
||||
providerSubscriptionID string,
|
||||
) (billing_provider.ProviderSubscription, error) {
|
||||
logger = logger.With("provider_subscription_id", providerSubscriptionID)
|
||||
logger.Debug("paddle: getting subscription details")
|
||||
|
||||
subscription, err := s.client.GetSubscription(context.Background(), &paddle.GetSubscriptionRequest{
|
||||
SubscriptionID: providerSubscriptionID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("paddle: failed to get subscription", "error", err)
|
||||
return billing_provider.ProviderSubscription{}, err
|
||||
}
|
||||
|
||||
logger.Debug(
|
||||
fmt.Sprintf(
|
||||
"paddle: successfully got subscription details: status=%s, quantity=%d",
|
||||
subscription.Status,
|
||||
subscription.Items[0].Quantity,
|
||||
),
|
||||
)
|
||||
|
||||
return s.toProviderSubscription(logger, subscription)
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) CreatePortalSession(
|
||||
logger *slog.Logger,
|
||||
providerCustomerID, returnURL string,
|
||||
) (string, error) {
|
||||
logger = logger.With("provider_customer_id", providerCustomerID)
|
||||
logger.Debug("paddle: creating portal session")
|
||||
|
||||
subscriptions, err := s.client.ListSubscriptions(context.Background(), &paddle.ListSubscriptionsRequest{
|
||||
CustomerID: []string{providerCustomerID},
|
||||
Status: []string{
|
||||
string(paddle.SubscriptionStatusActive),
|
||||
string(paddle.SubscriptionStatusPastDue),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("paddle: failed to list subscriptions for portal session", "error", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
res := subscriptions.Next(context.Background())
|
||||
if !res.Ok() {
|
||||
if res.Err() != nil {
|
||||
logger.Error("paddle: failed to iterate subscriptions", "error", res.Err())
|
||||
return "", res.Err()
|
||||
}
|
||||
|
||||
logger.Error("paddle: no active subscriptions found for customer")
|
||||
return "", fmt.Errorf("no active subscriptions found for customer %s", providerCustomerID)
|
||||
}
|
||||
|
||||
subscription := res.Value()
|
||||
if subscription.ManagementURLs.UpdatePaymentMethod == nil {
|
||||
logger.Error("paddle: subscription has no management URL")
|
||||
return "", fmt.Errorf("subscription %s has no management URL", subscription.ID)
|
||||
}
|
||||
|
||||
return *subscription.ManagementURLs.UpdatePaymentMethod, nil
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) VerifyWebhookSignature(body []byte, headers map[string]string) error {
|
||||
req, _ := http.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader(body))
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
ok, err := s.webhookVerified.Verify(req)
|
||||
if err != nil || !ok {
|
||||
return fmt.Errorf("failed to verify webhook signature: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) ProcessWebhookEvent(
|
||||
logger *slog.Logger,
|
||||
requestID uuid.UUID,
|
||||
webhookDTO PaddleWebhookDTO,
|
||||
rawBody []byte,
|
||||
) error {
|
||||
webhookEvent, err := s.normalizeWebhookEvent(
|
||||
logger,
|
||||
requestID,
|
||||
webhookDTO.EventID,
|
||||
webhookDTO.EventType,
|
||||
webhookDTO.Data,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, billing_webhooks.ErrUnsupportedEventType) {
|
||||
return s.skipWebhookEvent(logger, requestID, webhookDTO, rawBody)
|
||||
}
|
||||
|
||||
logger.Error("paddle: failed to normalize webhook event", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
logArgs := []any{
|
||||
"provider_event_id", webhookEvent.ProviderEventID,
|
||||
"provider_subscription_id", webhookEvent.ProviderSubscriptionID,
|
||||
"provider_customer_id", webhookEvent.ProviderCustomerID,
|
||||
}
|
||||
if webhookEvent.DatabaseID != nil {
|
||||
logArgs = append(logArgs, "database_id", webhookEvent.DatabaseID)
|
||||
}
|
||||
|
||||
logger = logger.With(logArgs...)
|
||||
|
||||
existingRecord, err := s.webhookRepository.FindSuccessfulByProviderEventID(webhookEvent.ProviderEventID)
|
||||
if err == nil && existingRecord != nil {
|
||||
logger.Info("paddle: webhook already processed successfully, skipping",
|
||||
"existing_request_id", existingRecord.RequestID,
|
||||
)
|
||||
return billing_webhooks.ErrDuplicateWebhook
|
||||
}
|
||||
|
||||
webhookRecord := &billing_webhooks.WebhookRecord{
|
||||
RequestID: requestID,
|
||||
ProviderName: billing_provider.ProviderPaddle,
|
||||
EventType: string(webhookEvent.Type),
|
||||
ProviderEventID: webhookEvent.ProviderEventID,
|
||||
RawPayload: string(rawBody),
|
||||
}
|
||||
if err := s.webhookRepository.Insert(webhookRecord); err != nil {
|
||||
logger.Error("paddle: failed to save webhook record", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.processWebhookEvent(logger, webhookEvent); err != nil {
|
||||
logger.Error("paddle: failed to process webhook event", "error", err)
|
||||
|
||||
if markErr := s.webhookRepository.MarkError(requestID.String(), err.Error()); markErr != nil {
|
||||
logger.Error("paddle: failed to mark webhook as errored", "error", markErr)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
if markErr := s.webhookRepository.MarkProcessed(requestID.String()); markErr != nil {
|
||||
logger.Error("paddle: failed to mark webhook as processed", "error", markErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) skipWebhookEvent(
|
||||
logger *slog.Logger,
|
||||
requestID uuid.UUID,
|
||||
webhookDTO PaddleWebhookDTO,
|
||||
rawBody []byte,
|
||||
) error {
|
||||
webhookRecord := &billing_webhooks.WebhookRecord{
|
||||
RequestID: requestID,
|
||||
ProviderName: billing_provider.ProviderPaddle,
|
||||
EventType: webhookDTO.EventType,
|
||||
ProviderEventID: webhookDTO.EventID,
|
||||
RawPayload: string(rawBody),
|
||||
}
|
||||
|
||||
if err := s.webhookRepository.Insert(webhookRecord); err != nil {
|
||||
logger.Error("paddle: failed to save skipped webhook record", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.webhookRepository.MarkSkipped(requestID.String()); err != nil {
|
||||
logger.Error("paddle: failed to mark webhook as skipped", "error", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) processWebhookEvent(
|
||||
logger *slog.Logger,
|
||||
webhookEvent billing_models.WebhookEvent,
|
||||
) error {
|
||||
logger.Debug("processing webhook event")
|
||||
|
||||
// subscription.created - there is no subscription in the database yet
|
||||
if webhookEvent.Type == billing_models.WHEventSubscriptionCreated {
|
||||
return s.billingService.ActivateSubscription(logger, webhookEvent)
|
||||
}
|
||||
|
||||
// dispute - finds subscription via invoice, no provider subscription ID available
|
||||
if webhookEvent.Type == billing_models.WHEventSubscriptionDisputeCreated {
|
||||
return s.billingService.RecordDispute(logger, webhookEvent)
|
||||
}
|
||||
|
||||
// for others - search subscription first
|
||||
subscription, err := s.billingService.GetSubscriptionByProviderSubID(logger, webhookEvent.ProviderSubscriptionID)
|
||||
if err != nil {
|
||||
logger.Error("paddle: failed to find subscription for webhook event", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
logger = logger.With("subscription_id", subscription.ID, "database_id", subscription.DatabaseID)
|
||||
logger.Debug(fmt.Sprintf("found subscription in DB with ID: %s", subscription.ID))
|
||||
|
||||
switch webhookEvent.Type {
|
||||
case billing_models.WHEventSubscriptionUpdated:
|
||||
if subscription.Status == billing_models.StatusCanceled {
|
||||
return s.billingService.ReactivateSubscription(logger, subscription, webhookEvent)
|
||||
}
|
||||
|
||||
return s.billingService.SyncSubscriptionFromProvider(logger, subscription, webhookEvent)
|
||||
case billing_models.WHEventSubscriptionCanceled:
|
||||
return s.billingService.CancelSubscription(logger, subscription, webhookEvent)
|
||||
case billing_models.WHEventPaymentSucceeded:
|
||||
return s.billingService.RecordPaymentSuccess(logger, subscription, webhookEvent)
|
||||
case billing_models.WHEventSubscriptionPastDue:
|
||||
return s.billingService.RecordPaymentFailed(logger, subscription, webhookEvent)
|
||||
default:
|
||||
logger.Error(fmt.Sprintf("unhandled webhook event type: %s", string(webhookEvent.Type)))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) normalizeWebhookEvent(
|
||||
logger *slog.Logger,
|
||||
requestID uuid.UUID,
|
||||
eventID, eventType string,
|
||||
data json.RawMessage,
|
||||
) (billing_models.WebhookEvent, error) {
|
||||
webhookEvent := billing_models.WebhookEvent{
|
||||
RequestID: requestID,
|
||||
ProviderEventID: eventID,
|
||||
}
|
||||
|
||||
switch eventType {
|
||||
case "subscription.created":
|
||||
webhookEvent.Type = billing_models.WHEventSubscriptionCreated
|
||||
|
||||
var subscription paddle.Subscription
|
||||
if err := json.Unmarshal(data, &subscription); err != nil {
|
||||
logger.Error("paddle: failed to unmarshal subscription.created webhook data", "error", err)
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.ProviderSubscriptionID = subscription.ID
|
||||
webhookEvent.ProviderCustomerID = subscription.CustomerID
|
||||
webhookEvent.QuantityGB = subscription.Items[0].Quantity
|
||||
status, err := mapPaddleStatus(logger, subscription.Status)
|
||||
if err != nil {
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.Status = status
|
||||
|
||||
if subscription.CurrentBillingPeriod != nil {
|
||||
periodStart := mustParseRFC3339(logger, "period start", subscription.CurrentBillingPeriod.StartsAt)
|
||||
periodEnd := mustParseRFC3339(logger, "period end", subscription.CurrentBillingPeriod.EndsAt)
|
||||
|
||||
webhookEvent.PeriodStart = &periodStart
|
||||
webhookEvent.PeriodEnd = &periodEnd
|
||||
}
|
||||
|
||||
if subscription.CustomData == nil || subscription.CustomData["database_id"] == "" {
|
||||
logger.Error("paddle: subscription has no database_id in custom data")
|
||||
}
|
||||
|
||||
databaseIDStr, isOk := subscription.CustomData["database_id"].(string)
|
||||
if !isOk {
|
||||
logger.Error("paddle: database_id in custom data is not a string")
|
||||
return webhookEvent, fmt.Errorf("invalid database_id type in custom data")
|
||||
}
|
||||
|
||||
databaseID := uuid.MustParse(databaseIDStr)
|
||||
webhookEvent.DatabaseID = &databaseID
|
||||
|
||||
case "subscription.updated":
|
||||
var subscription paddle.Subscription
|
||||
if err := json.Unmarshal(data, &subscription); err != nil {
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.ProviderSubscriptionID = subscription.ID
|
||||
webhookEvent.ProviderCustomerID = subscription.CustomerID
|
||||
webhookEvent.QuantityGB = subscription.Items[0].Quantity
|
||||
webhookEvent.Type = billing_models.WHEventSubscriptionUpdated
|
||||
|
||||
status, err := mapPaddleStatus(logger, subscription.Status)
|
||||
if err != nil {
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.Status = status
|
||||
|
||||
if subscription.CurrentBillingPeriod != nil {
|
||||
periodStart := mustParseRFC3339(logger, "period start", subscription.CurrentBillingPeriod.StartsAt)
|
||||
periodEnd := mustParseRFC3339(logger, "period end", subscription.CurrentBillingPeriod.EndsAt)
|
||||
|
||||
webhookEvent.PeriodStart = &periodStart
|
||||
webhookEvent.PeriodEnd = &periodEnd
|
||||
}
|
||||
|
||||
if subscription.ScheduledChange != nil &&
|
||||
subscription.ScheduledChange.Action == paddle.ScheduledChangeActionCancel {
|
||||
webhookEvent.Type = billing_models.WHEventSubscriptionCanceled
|
||||
}
|
||||
|
||||
case "subscription.canceled":
|
||||
webhookEvent.Type = billing_models.WHEventSubscriptionCanceled
|
||||
|
||||
var subscription paddle.Subscription
|
||||
if err := json.Unmarshal(data, &subscription); err != nil {
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.ProviderSubscriptionID = subscription.ID
|
||||
webhookEvent.ProviderCustomerID = subscription.CustomerID
|
||||
|
||||
status, err := mapPaddleStatus(logger, subscription.Status)
|
||||
if err != nil {
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.Status = status
|
||||
|
||||
case "transaction.completed":
|
||||
webhookEvent.Type = billing_models.WHEventPaymentSucceeded
|
||||
|
||||
var transaction paddle.Transaction
|
||||
if err := json.Unmarshal(data, &transaction); err != nil {
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.ProviderInvoiceID = transaction.ID
|
||||
|
||||
if len(transaction.Items) > 0 {
|
||||
webhookEvent.QuantityGB = transaction.Items[0].Quantity
|
||||
}
|
||||
|
||||
if transaction.SubscriptionID != nil {
|
||||
webhookEvent.ProviderSubscriptionID = *transaction.SubscriptionID
|
||||
}
|
||||
|
||||
if transaction.CustomerID != nil {
|
||||
webhookEvent.ProviderCustomerID = *transaction.CustomerID
|
||||
}
|
||||
|
||||
amountCents, err := strconv.ParseInt(transaction.Details.Totals.Total, 10, 64)
|
||||
if err != nil {
|
||||
logger.Error("paddle: failed to parse transaction total", "error", err)
|
||||
} else {
|
||||
webhookEvent.AmountCents = amountCents
|
||||
}
|
||||
|
||||
if transaction.BillingPeriod != nil {
|
||||
periodStart := mustParseRFC3339(logger, "period start", transaction.BillingPeriod.StartsAt)
|
||||
periodEnd := mustParseRFC3339(logger, "period end", transaction.BillingPeriod.EndsAt)
|
||||
|
||||
webhookEvent.PeriodStart = &periodStart
|
||||
webhookEvent.PeriodEnd = &periodEnd
|
||||
}
|
||||
|
||||
case "subscription.past_due":
|
||||
webhookEvent.Type = billing_models.WHEventSubscriptionPastDue
|
||||
|
||||
var subscription paddle.Subscription
|
||||
if err := json.Unmarshal(data, &subscription); err != nil {
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.ProviderSubscriptionID = subscription.ID
|
||||
webhookEvent.ProviderCustomerID = subscription.CustomerID
|
||||
webhookEvent.QuantityGB = subscription.Items[0].Quantity
|
||||
|
||||
status, err := mapPaddleStatus(logger, subscription.Status)
|
||||
if err != nil {
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.Status = status
|
||||
|
||||
if subscription.CurrentBillingPeriod != nil {
|
||||
periodStart := mustParseRFC3339(logger, "period start", subscription.CurrentBillingPeriod.StartsAt)
|
||||
periodEnd := mustParseRFC3339(logger, "period end", subscription.CurrentBillingPeriod.EndsAt)
|
||||
|
||||
webhookEvent.PeriodStart = &periodStart
|
||||
webhookEvent.PeriodEnd = &periodEnd
|
||||
}
|
||||
|
||||
case "adjustment.created":
|
||||
webhookEvent.Type = billing_models.WHEventSubscriptionDisputeCreated
|
||||
|
||||
var adjustment struct {
|
||||
TransactionID string `json:"transaction_id"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &adjustment); err != nil {
|
||||
return webhookEvent, err
|
||||
}
|
||||
|
||||
webhookEvent.ProviderInvoiceID = adjustment.TransactionID
|
||||
|
||||
default:
|
||||
logger.Debug("unsupported paddle event type, skipping", "event_type", eventType)
|
||||
return webhookEvent, billing_webhooks.ErrUnsupportedEventType
|
||||
}
|
||||
|
||||
return webhookEvent, nil
|
||||
}
|
||||
|
||||
func (s *PaddleBillingService) toProviderSubscription(
|
||||
logger *slog.Logger,
|
||||
paddleSubscription *paddle.Subscription,
|
||||
) (billing_provider.ProviderSubscription, error) {
|
||||
status, err := mapPaddleStatus(logger, paddleSubscription.Status)
|
||||
if err != nil {
|
||||
return billing_provider.ProviderSubscription{}, err
|
||||
}
|
||||
|
||||
if len(paddleSubscription.Items) == 0 {
|
||||
return billing_provider.ProviderSubscription{}, fmt.Errorf(
|
||||
"paddle subscription %s has no items",
|
||||
paddleSubscription.ID,
|
||||
)
|
||||
}
|
||||
|
||||
providerSubscription := &billing_provider.ProviderSubscription{
|
||||
ProviderSubscriptionID: paddleSubscription.ID,
|
||||
ProviderCustomerID: paddleSubscription.CustomerID,
|
||||
Status: status,
|
||||
QuantityGB: paddleSubscription.Items[0].Quantity,
|
||||
}
|
||||
|
||||
if paddleSubscription.CurrentBillingPeriod != nil {
|
||||
providerSubscription.PeriodStart = mustParseRFC3339(
|
||||
logger,
|
||||
"period start",
|
||||
paddleSubscription.CurrentBillingPeriod.StartsAt,
|
||||
)
|
||||
providerSubscription.PeriodEnd = mustParseRFC3339(
|
||||
logger,
|
||||
"period end",
|
||||
paddleSubscription.CurrentBillingPeriod.EndsAt,
|
||||
)
|
||||
}
|
||||
|
||||
return *providerSubscription, nil
|
||||
}
|
||||
|
||||
func mustParseRFC3339(logger *slog.Logger, label, value string) time.Time {
|
||||
parsed, err := time.Parse(time.RFC3339, value)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("paddle: failed to parse %s", label), "error", err)
|
||||
}
|
||||
|
||||
return parsed
|
||||
}
|
||||
|
||||
func mapPaddleStatus(logger *slog.Logger, s paddle.SubscriptionStatus) (billing_models.SubscriptionStatus, error) {
|
||||
switch s {
|
||||
case paddle.SubscriptionStatusActive:
|
||||
return billing_models.StatusActive, nil
|
||||
case paddle.SubscriptionStatusPastDue:
|
||||
return billing_models.StatusPastDue, nil
|
||||
case paddle.SubscriptionStatusCanceled:
|
||||
return billing_models.StatusCanceled, nil
|
||||
case paddle.SubscriptionStatusTrialing:
|
||||
return billing_models.StatusTrial, nil
|
||||
case paddle.SubscriptionStatusPaused:
|
||||
return billing_models.StatusCanceled, nil
|
||||
default:
|
||||
logger.Error(fmt.Sprintf("paddle: unknown subscription status: %s", string(s)))
|
||||
|
||||
return "", fmt.Errorf("paddle: unknown subscription status: %s", string(s))
|
||||
}
|
||||
}
|
||||
38
backend/internal/features/billing/provider/dto.go
Normal file
38
backend/internal/features/billing/provider/dto.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package billing_provider
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
)
|
||||
|
||||
type CreateSubscriptionRequest struct {
|
||||
ProviderCustomerID string
|
||||
DatabaseID uuid.UUID
|
||||
StorageGB int
|
||||
}
|
||||
|
||||
type ProviderSubscription struct {
|
||||
ProviderSubscriptionID string
|
||||
ProviderCustomerID string
|
||||
Status billing_models.SubscriptionStatus
|
||||
QuantityGB int
|
||||
PeriodStart time.Time
|
||||
PeriodEnd time.Time
|
||||
}
|
||||
|
||||
type CheckoutRequest struct {
|
||||
DatabaseID uuid.UUID
|
||||
Email string
|
||||
StorageGB int
|
||||
SuccessURL string
|
||||
CancelURL string
|
||||
}
|
||||
|
||||
type ProviderName string
|
||||
|
||||
const (
|
||||
ProviderPaddle ProviderName = "paddle"
|
||||
)
|
||||
21
backend/internal/features/billing/provider/provider.go
Normal file
21
backend/internal/features/billing/provider/provider.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package billing_provider
|
||||
|
||||
import "log/slog"
|
||||
|
||||
type BillingProvider interface {
|
||||
GetProviderName() ProviderName
|
||||
|
||||
UpgradeQuantityWithSurcharge(logger *slog.Logger, providerSubscriptionID string, quantityGB int) error
|
||||
|
||||
ScheduleQuantityDowngradeFromNextBillingCycle(
|
||||
logger *slog.Logger,
|
||||
providerSubscriptionID string,
|
||||
quantityGB int,
|
||||
) error
|
||||
|
||||
GetSubscription(logger *slog.Logger, providerSubscriptionID string) (ProviderSubscription, error)
|
||||
|
||||
CreateCheckoutSession(logger *slog.Logger, req CheckoutRequest) (checkoutURL string, err error)
|
||||
|
||||
CreatePortalSession(logger *slog.Logger, providerCustomerID, returnURL string) (portalURL string, err error)
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package billing_repositories
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
type InvoiceRepository struct{}
|
||||
|
||||
func (r *InvoiceRepository) Save(invoice billing_models.Invoice) error {
|
||||
if invoice.SubscriptionID == uuid.Nil {
|
||||
return errors.New("subscription id is required")
|
||||
}
|
||||
|
||||
db := storage.GetDb()
|
||||
|
||||
if invoice.ID == uuid.Nil {
|
||||
invoice.ID = uuid.New()
|
||||
return db.Create(&invoice).Error
|
||||
}
|
||||
|
||||
return db.Save(invoice).Error
|
||||
}
|
||||
|
||||
func (r *InvoiceRepository) FindByProviderInvID(providerInvoiceID string) (*billing_models.Invoice, error) {
|
||||
var invoice billing_models.Invoice
|
||||
|
||||
if err := storage.GetDb().Where("provider_invoice_id = ?", providerInvoiceID).
|
||||
First(&invoice).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &invoice, nil
|
||||
}
|
||||
|
||||
func (r *InvoiceRepository) FindByDatabaseID(
|
||||
databaseID uuid.UUID,
|
||||
limit, offset int,
|
||||
) ([]*billing_models.Invoice, error) {
|
||||
var invoices []*billing_models.Invoice
|
||||
|
||||
if err := storage.GetDb().Joins("JOIN subscriptions ON subscriptions.id = invoices.subscription_id").
|
||||
Where("subscriptions.database_id = ?", databaseID).
|
||||
Order("invoices.created_at DESC").
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&invoices).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invoices, nil
|
||||
}
|
||||
|
||||
func (r *InvoiceRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error) {
|
||||
var count int64
|
||||
|
||||
err := storage.GetDb().Model(&billing_models.Invoice{}).
|
||||
Joins("JOIN subscriptions ON subscriptions.id = invoices.subscription_id").
|
||||
Where("subscriptions.database_id = ?", databaseID).
|
||||
Count(&count).Error
|
||||
|
||||
return count, err
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package billing_repositories
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
type SubscriptionEventRepository struct{}
|
||||
|
||||
func (r *SubscriptionEventRepository) Create(event billing_models.SubscriptionEvent) error {
|
||||
if event.SubscriptionID == uuid.Nil {
|
||||
return errors.New("subscription id is required")
|
||||
}
|
||||
|
||||
event.ID = uuid.New()
|
||||
return storage.GetDb().Create(&event).Error
|
||||
}
|
||||
|
||||
func (r *SubscriptionEventRepository) FindByDatabaseID(
|
||||
databaseID uuid.UUID,
|
||||
limit, offset int,
|
||||
) ([]*billing_models.SubscriptionEvent, error) {
|
||||
var events []*billing_models.SubscriptionEvent
|
||||
|
||||
if err := storage.GetDb().Joins("JOIN subscriptions ON subscriptions.id = subscription_events.subscription_id").
|
||||
Where("subscriptions.database_id = ?", databaseID).
|
||||
Order("subscription_events.created_at DESC").
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&events).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
func (r *SubscriptionEventRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error) {
|
||||
var count int64
|
||||
|
||||
err := storage.GetDb().Model(&billing_models.SubscriptionEvent{}).
|
||||
Joins("JOIN subscriptions ON subscriptions.id = subscription_events.subscription_id").
|
||||
Where("subscriptions.database_id = ?", databaseID).
|
||||
Count(&count).Error
|
||||
|
||||
return count, err
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
package billing_repositories
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
type SubscriptionRepository struct{}
|
||||
|
||||
func (r *SubscriptionRepository) Save(sub billing_models.Subscription) error {
|
||||
db := storage.GetDb()
|
||||
|
||||
if sub.ID == uuid.Nil {
|
||||
sub.ID = uuid.New()
|
||||
return db.Create(&sub).Error
|
||||
}
|
||||
|
||||
return db.Save(&sub).Error
|
||||
}
|
||||
|
||||
func (r *SubscriptionRepository) FindByID(id uuid.UUID) (*billing_models.Subscription, error) {
|
||||
var sub billing_models.Subscription
|
||||
|
||||
if err := storage.GetDb().Where("id = ?", id).First(&sub).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
func (r *SubscriptionRepository) FindByDatabaseIDAndStatuses(
|
||||
databaseID uuid.UUID,
|
||||
stauses []billing_models.SubscriptionStatus,
|
||||
) ([]*billing_models.Subscription, error) {
|
||||
var subs []*billing_models.Subscription
|
||||
|
||||
if err := storage.GetDb().Where("database_id = ? AND status IN ?", databaseID, stauses).
|
||||
Find(&subs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return subs, nil
|
||||
}
|
||||
|
||||
func (r *SubscriptionRepository) FindLatestByDatabaseID(databaseID uuid.UUID) (*billing_models.Subscription, error) {
|
||||
var sub billing_models.Subscription
|
||||
|
||||
if err := storage.GetDb().
|
||||
Where("database_id = ?", databaseID).
|
||||
Order("created_at DESC").
|
||||
First(&sub).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
func (r *SubscriptionRepository) FindByProviderSubID(providerSubID string) (*billing_models.Subscription, error) {
|
||||
var sub billing_models.Subscription
|
||||
|
||||
if err := storage.GetDb().Where("provider_sub_id = ?", providerSubID).
|
||||
First(&sub).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
func (r *SubscriptionRepository) FindByStatuses(
|
||||
statuses []billing_models.SubscriptionStatus,
|
||||
) ([]billing_models.Subscription, error) {
|
||||
var subs []billing_models.Subscription
|
||||
|
||||
if err := storage.GetDb().Where("status IN ?", statuses).Find(&subs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return subs, nil
|
||||
}
|
||||
|
||||
func (r *SubscriptionRepository) FindCanceledWithEndedGracePeriod(
|
||||
now time.Time,
|
||||
) ([]billing_models.Subscription, error) {
|
||||
var subs []billing_models.Subscription
|
||||
|
||||
if err := storage.GetDb().
|
||||
Where("status = ? AND data_retention_grace_period_until < ?", billing_models.StatusCanceled, now).
|
||||
Find(&subs).
|
||||
Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return subs, nil
|
||||
}
|
||||
|
||||
func (r *SubscriptionRepository) FindExpiredTrials(now time.Time) ([]billing_models.Subscription, error) {
|
||||
var subs []billing_models.Subscription
|
||||
|
||||
if err := storage.GetDb().Where("status = ? AND current_period_end < ?", billing_models.StatusTrial, now).
|
||||
Find(&subs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return subs, nil
|
||||
}
|
||||
1253
backend/internal/features/billing/service.go
Normal file
1253
backend/internal/features/billing/service.go
Normal file
File diff suppressed because it is too large
Load Diff
8
backend/internal/features/billing/webhooks/errors.go
Normal file
8
backend/internal/features/billing/webhooks/errors.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package billing_webhooks
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrDuplicateWebhook = errors.New("duplicate webhook event")
|
||||
ErrUnsupportedEventType = errors.New("unsupported webhook event type")
|
||||
)
|
||||
25
backend/internal/features/billing/webhooks/model.go
Normal file
25
backend/internal/features/billing/webhooks/model.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package billing_webhooks
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
billing_provider "databasus-backend/internal/features/billing/provider"
|
||||
)
|
||||
|
||||
type WebhookRecord struct {
|
||||
RequestID uuid.UUID `gorm:"column:request_id;primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
ProviderName billing_provider.ProviderName `gorm:"column:provider_name;type:text;not null"`
|
||||
EventType string `gorm:"column:event_type;type:text;not null"`
|
||||
ProviderEventID string `gorm:"column:provider_event_id;type:text;not null;index"`
|
||||
RawPayload string `gorm:"column:raw_payload;type:text;not null"`
|
||||
ProcessedAt *time.Time `gorm:"column:processed_at"`
|
||||
IsSkipped bool `gorm:"column:is_skipped;not null;default:false"`
|
||||
Error *string `gorm:"column:error"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;not null"`
|
||||
}
|
||||
|
||||
func (WebhookRecord) TableName() string {
|
||||
return "webhook_records"
|
||||
}
|
||||
73
backend/internal/features/billing/webhooks/repository.go
Normal file
73
backend/internal/features/billing/webhooks/repository.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package billing_webhooks
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
type WebhookRepository struct{}
|
||||
|
||||
func (r *WebhookRepository) FindSuccessfulByProviderEventID(providerEventID string) (*WebhookRecord, error) {
|
||||
var record WebhookRecord
|
||||
|
||||
err := storage.GetDb().
|
||||
Where("provider_event_id = ? AND processed_at IS NOT NULL", providerEventID).
|
||||
First(&record).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (r *WebhookRepository) Insert(record *WebhookRecord) error {
|
||||
if record.ProviderEventID == "" {
|
||||
return errors.New("provider event ID is required")
|
||||
}
|
||||
|
||||
record.CreatedAt = time.Now().UTC()
|
||||
|
||||
return storage.GetDb().Create(record).Error
|
||||
}
|
||||
|
||||
func (r *WebhookRepository) MarkProcessed(requestID string) error {
|
||||
now := time.Now().UTC()
|
||||
|
||||
return storage.
|
||||
GetDb().
|
||||
Model(&WebhookRecord{}).
|
||||
Where("request_id = ?", requestID).
|
||||
Update("processed_at", now).
|
||||
Error
|
||||
}
|
||||
|
||||
func (r *WebhookRepository) MarkSkipped(requestID string) error {
|
||||
now := time.Now().UTC()
|
||||
|
||||
return storage.
|
||||
GetDb().
|
||||
Model(&WebhookRecord{}).
|
||||
Where("request_id = ?", requestID).
|
||||
Updates(map[string]any{
|
||||
"is_skipped": true,
|
||||
"processed_at": now,
|
||||
}).
|
||||
Error
|
||||
}
|
||||
|
||||
func (r *WebhookRepository) MarkError(requestID, errMsg string) error {
|
||||
return storage.
|
||||
GetDb().
|
||||
Model(&WebhookRecord{}).
|
||||
Where("request_id = ?", requestID).
|
||||
Update("error", errMsg).
|
||||
Error
|
||||
}
|
||||
@@ -1328,6 +1328,143 @@ func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateDatabase_WhenCloudAndUserIsNotReadOnly_ReturnsBadRequest(t *testing.T) {
|
||||
enableCloud(t)
|
||||
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Cloud Not ReadOnly", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
request := Database{
|
||||
Name: "Cloud Non-ReadOnly DB",
|
||||
WorkspaceID: &workspace.ID,
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: getTestPostgresConfig(),
|
||||
}
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "in cloud mode, only read-only database users are allowed")
|
||||
}
|
||||
|
||||
func Test_CreateDatabase_WhenCloudAndUserIsReadOnly_DatabaseCreated(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Cloud ReadOnly", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Temp DB for RO User", workspace.ID, owner.Token, router)
|
||||
|
||||
readOnlyUser := createReadOnlyUserViaAPI(t, router, database.ID, owner.Token)
|
||||
assert.NotEmpty(t, readOnlyUser.Username)
|
||||
assert.NotEmpty(t, readOnlyUser.Password)
|
||||
|
||||
RemoveTestDatabase(database)
|
||||
|
||||
enableCloud(t)
|
||||
|
||||
pgConfig := getTestPostgresConfig()
|
||||
pgConfig.Username = readOnlyUser.Username
|
||||
pgConfig.Password = readOnlyUser.Password
|
||||
|
||||
request := Database{
|
||||
Name: "Cloud ReadOnly DB",
|
||||
WorkspaceID: &workspace.ID,
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: pgConfig,
|
||||
}
|
||||
|
||||
var response Database
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusCreated,
|
||||
&response,
|
||||
)
|
||||
defer RemoveTestDatabase(&response)
|
||||
|
||||
assert.Equal(t, "Cloud ReadOnly DB", response.Name)
|
||||
assert.NotEqual(t, uuid.Nil, response.ID)
|
||||
}
|
||||
|
||||
func Test_CreateDatabase_WhenNotCloudAndUserIsNotReadOnly_DatabaseCreated(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Non-Cloud", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
request := Database{
|
||||
Name: "Non-Cloud DB",
|
||||
WorkspaceID: &workspace.ID,
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: getTestPostgresConfig(),
|
||||
}
|
||||
|
||||
var response Database
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusCreated,
|
||||
&response,
|
||||
)
|
||||
defer RemoveTestDatabase(&response)
|
||||
|
||||
assert.Equal(t, "Non-Cloud DB", response.Name)
|
||||
assert.NotEqual(t, uuid.Nil, response.ID)
|
||||
}
|
||||
|
||||
func enableCloud(t *testing.T) {
|
||||
t.Helper()
|
||||
config.GetEnv().IsCloud = true
|
||||
t.Cleanup(func() {
|
||||
config.GetEnv().IsCloud = false
|
||||
})
|
||||
}
|
||||
|
||||
func createReadOnlyUserViaAPI(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
databaseID uuid.UUID,
|
||||
token string,
|
||||
) *CreateReadOnlyUserResponse {
|
||||
var database Database
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/databases/%s", databaseID.String()),
|
||||
"Bearer "+token,
|
||||
http.StatusOK,
|
||||
&database,
|
||||
)
|
||||
|
||||
var response CreateReadOnlyUserResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/create-readonly-user",
|
||||
"Bearer "+token,
|
||||
database,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
return &response
|
||||
}
|
||||
|
||||
func getTestMariadbConfig() *mariadb.MariadbDatabase {
|
||||
env := config.GetEnv()
|
||||
portStr := env.TestMariadb1011Port
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package mariadb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
@@ -212,7 +211,7 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
isReadOnly, privileges, err := mariadbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -241,7 +240,7 @@ func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -313,7 +312,7 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -390,7 +389,7 @@ func Test_ReadOnlyUser_FutureTables_NoSelectPermission(t *testing.T) {
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -466,7 +465,7 @@ func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -511,7 +510,7 @@ func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -42,9 +42,9 @@ func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMongodbContainer(t, tc.port, tc.version)
|
||||
defer container.Client.Disconnect(context.Background())
|
||||
defer container.Client.Disconnect(t.Context())
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
db := container.Client.Database(container.Database)
|
||||
|
||||
_ = db.Collection("permission_test").Drop(ctx)
|
||||
@@ -108,9 +108,9 @@ func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMongodbContainer(t, tc.port, tc.version)
|
||||
defer container.Client.Disconnect(context.Background())
|
||||
defer container.Client.Disconnect(t.Context())
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
db := container.Client.Database(container.Database)
|
||||
|
||||
_ = db.Collection("backup_test").Drop(ctx)
|
||||
@@ -178,11 +178,11 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMongodbContainer(t, tc.port, tc.version)
|
||||
defer container.Client.Disconnect(context.Background())
|
||||
defer container.Client.Disconnect(t.Context())
|
||||
|
||||
mongodbModel := createMongodbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
isReadOnly, roles, err := mongodbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -195,9 +195,9 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMongodbContainer(t, env.TestMongodb70Port, tools.MongodbVersion7)
|
||||
defer container.Client.Disconnect(context.Background())
|
||||
defer container.Client.Disconnect(t.Context())
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
db := container.Client.Database(container.Database)
|
||||
|
||||
_ = db.Collection("readonly_check_test").Drop(ctx)
|
||||
@@ -251,15 +251,15 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMongodbContainer(t, tc.port, tc.version)
|
||||
defer container.Client.Disconnect(context.Background())
|
||||
defer container.Client.Disconnect(t.Context())
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
db := container.Client.Database(container.Database)
|
||||
|
||||
_ = db.Collection("readonly_test").Drop(ctx)
|
||||
_ = db.Collection("hack_collection").Drop(ctx)
|
||||
|
||||
_, err := db.Collection("readonly_test").InsertMany(ctx, []interface{}{
|
||||
_, err := db.Collection("readonly_test").InsertMany(ctx, []any{
|
||||
bson.M{"data": "test1"},
|
||||
bson.M{"data": "test2"},
|
||||
})
|
||||
@@ -317,9 +317,9 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
func Test_ReadOnlyUser_FutureCollections_CanSelect(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMongodbContainer(t, env.TestMongodb70Port, tools.MongodbVersion7)
|
||||
defer container.Client.Disconnect(context.Background())
|
||||
defer container.Client.Disconnect(t.Context())
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
db := container.Client.Database(container.Database)
|
||||
|
||||
mongodbModel := createMongodbModel(container)
|
||||
@@ -348,9 +348,9 @@ func Test_ReadOnlyUser_FutureCollections_CanSelect(t *testing.T) {
|
||||
func Test_ReadOnlyUser_CannotDropOrModifyCollections(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMongodbContainer(t, env.TestMongodb70Port, tools.MongodbVersion7)
|
||||
defer container.Client.Disconnect(context.Background())
|
||||
defer container.Client.Disconnect(t.Context())
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
db := container.Client.Database(container.Database)
|
||||
|
||||
_ = db.Collection("drop_test").Drop(ctx)
|
||||
@@ -420,7 +420,7 @@ func connectToMongodbContainer(
|
||||
authDatabase,
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
clientOptions := options.Client().ApplyURI(uri)
|
||||
@@ -473,7 +473,7 @@ func connectWithCredentials(
|
||||
container.Database, container.AuthDatabase,
|
||||
)
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
clientOptions := options.Client().ApplyURI(uri)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
@@ -231,7 +230,7 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
|
||||
mysqlModel := createMysqlModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
isReadOnly, privileges, err := mysqlModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -260,7 +259,7 @@ func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
|
||||
|
||||
mysqlModel := createMysqlModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -326,7 +325,7 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
|
||||
mysqlModel := createMysqlModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -400,7 +399,7 @@ func Test_ReadOnlyUser_FutureTables_NoSelectPermission(t *testing.T) {
|
||||
|
||||
mysqlModel := createMysqlModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -477,7 +476,7 @@ func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -523,7 +522,7 @@ func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
|
||||
|
||||
mysqlModel := createMysqlModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -81,8 +81,8 @@ func (p *PostgresqlDatabase) Validate() error {
|
||||
p.BackupType = PostgresBackupTypePgDump
|
||||
}
|
||||
|
||||
if p.BackupType == PostgresBackupTypePgDump && config.GetEnv().IsCloud {
|
||||
return errors.New("PG_DUMP backup type is not supported in cloud mode")
|
||||
if p.BackupType != PostgresBackupTypePgDump && config.GetEnv().IsCloud {
|
||||
return errors.New("only PG_DUMP backup type is supported in cloud mode")
|
||||
}
|
||||
|
||||
if p.BackupType == PostgresBackupTypePgDump {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
@@ -267,7 +266,7 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
isReadOnly, privileges, err := pgModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -294,7 +293,7 @@ func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -359,7 +358,7 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -438,7 +437,7 @@ func Test_ReadOnlyUser_FutureTables_HaveSelectPermission(t *testing.T) {
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -491,7 +490,7 @@ func Test_ReadOnlyUser_MultipleSchemas_AllAccessible(t *testing.T) {
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -566,7 +565,7 @@ func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -653,7 +652,7 @@ func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
connectionUsername, newPassword, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -743,7 +742,7 @@ func Test_CreateReadOnlyUser_WithPublicSchema_Success(t *testing.T) {
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
@@ -851,7 +850,7 @@ func Test_CreateReadOnlyUser_WithoutPublicSchema_Success(t *testing.T) {
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err, "CreateReadOnlyUser should succeed without public schema")
|
||||
@@ -1018,7 +1017,7 @@ func Test_CreateReadOnlyUser_PublicSchemaExistsButNoPermissions_ReturnsError(t *
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.Error(
|
||||
@@ -1310,6 +1309,46 @@ func Test_Validate_WhenRequiredFieldsMissing_ReturnsError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Validate_WhenCloudAndBackupTypeIsNotPgDump_ValidationFails(t *testing.T) {
|
||||
enableCloud(t)
|
||||
|
||||
model := &PostgresqlDatabase{
|
||||
Host: "example.com",
|
||||
Port: 5432,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CpuCount: 1,
|
||||
BackupType: PostgresBackupTypeWalV1,
|
||||
}
|
||||
|
||||
err := model.Validate()
|
||||
assert.EqualError(t, err, "only PG_DUMP backup type is supported in cloud mode")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenCloudAndBackupTypeIsPgDump_ValidationPasses(t *testing.T) {
|
||||
enableCloud(t)
|
||||
|
||||
model := &PostgresqlDatabase{
|
||||
Host: "example.com",
|
||||
Port: 5432,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CpuCount: 1,
|
||||
BackupType: PostgresBackupTypePgDump,
|
||||
}
|
||||
|
||||
err := model.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func enableCloud(t *testing.T) {
|
||||
t.Helper()
|
||||
config.GetEnv().IsCloud = true
|
||||
t.Cleanup(func() {
|
||||
config.GetEnv().IsCloud = false
|
||||
})
|
||||
}
|
||||
|
||||
type PostgresContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
@@ -1395,7 +1434,7 @@ func Test_CreateReadOnlyUser_TablesCreatedByDifferentUser_ReadOnlyUserCanRead(t
|
||||
// At this point, user_creator already owns objects, so ALTER DEFAULT PRIVILEGES FOR ROLE should apply
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
|
||||
ctx,
|
||||
@@ -1562,7 +1601,7 @@ func Test_CreateReadOnlyUser_WithIncludeSchemas_OnlyGrantsAccessToSpecifiedSchem
|
||||
pgModel.IncludeSchemas = []string{"public", "included_schema"}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
|
||||
ctx,
|
||||
|
||||
@@ -2,7 +2,6 @@ package databases
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
@@ -40,22 +39,7 @@ func GetDatabaseController() *DatabaseController {
|
||||
return databaseController
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
|
||||
notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
var SetupDependencies = sync.OnceFunc(func() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
|
||||
notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
|
||||
})
|
||||
|
||||
@@ -25,16 +25,16 @@ type Database struct {
|
||||
Name string `json:"name" gorm:"column:name;type:text;not null"`
|
||||
Type DatabaseType `json:"type" gorm:"column:type;type:text;not null"`
|
||||
|
||||
Postgresql *postgresql.PostgresqlDatabase `json:"postgresql,omitempty" gorm:"foreignKey:DatabaseID"`
|
||||
Mysql *mysql.MysqlDatabase `json:"mysql,omitempty" gorm:"foreignKey:DatabaseID"`
|
||||
Mariadb *mariadb.MariadbDatabase `json:"mariadb,omitempty" gorm:"foreignKey:DatabaseID"`
|
||||
Mongodb *mongodb.MongodbDatabase `json:"mongodb,omitempty" gorm:"foreignKey:DatabaseID"`
|
||||
Postgresql *postgresql.PostgresqlDatabase `json:"postgresql,omitzero" gorm:"foreignKey:DatabaseID"`
|
||||
Mysql *mysql.MysqlDatabase `json:"mysql,omitzero" gorm:"foreignKey:DatabaseID"`
|
||||
Mariadb *mariadb.MariadbDatabase `json:"mariadb,omitzero" gorm:"foreignKey:DatabaseID"`
|
||||
Mongodb *mongodb.MongodbDatabase `json:"mongodb,omitzero" gorm:"foreignKey:DatabaseID"`
|
||||
|
||||
Notifiers []notifiers.Notifier `json:"notifiers" gorm:"many2many:database_notifiers;"`
|
||||
|
||||
// these fields are not reliable, but
|
||||
// they are used for pretty UI
|
||||
LastBackupTime *time.Time `json:"lastBackupTime,omitempty" gorm:"column:last_backup_time;type:timestamp with time zone"`
|
||||
LastBackupTime *time.Time `json:"lastBackupTime,omitzero" gorm:"column:last_backup_time;type:timestamp with time zone"`
|
||||
LastBackupErrorMessage *string `json:"lastBackupErrorMessage,omitempty" gorm:"column:last_backup_error_message;type:text"`
|
||||
|
||||
HealthStatus *HealthStatus `json:"healthStatus" gorm:"column:health_status;type:text;not null"`
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
package secrets
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -16,34 +15,28 @@ type HealthcheckAttemptBackgroundService struct {
|
||||
checkDatabaseHealthUseCase *CheckDatabaseHealthUseCase
|
||||
logger *slog.Logger
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) {
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
// first healthcheck immediately
|
||||
s.checkDatabases()
|
||||
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.checkDatabases()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
if s.hasRun.Swap(true) {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
|
||||
// first healthcheck immediately
|
||||
s.checkDatabases()
|
||||
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.checkDatabases()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HealthcheckAttemptBackgroundService) checkDatabases() {
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
package healthcheck_attempt
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"databasus-backend/internal/features/databases"
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
@@ -30,8 +27,6 @@ var healthcheckAttemptBackgroundService = &HealthcheckAttemptBackgroundService{
|
||||
healthcheckConfigService: healthcheck_config.GetHealthcheckConfigService(),
|
||||
checkDatabaseHealthUseCase: checkDatabaseHealthUseCase,
|
||||
logger: logger.GetLogger(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
var healthcheckAttemptController = &HealthcheckAttemptController{
|
||||
|
||||
@@ -2,7 +2,6 @@ package healthcheck_config
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -33,23 +32,8 @@ func GetHealthcheckConfigController() *HealthcheckConfigController {
|
||||
return healthcheckConfigController
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
databases.
|
||||
GetDatabaseService().
|
||||
AddDbCreationListener(healthcheckConfigService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
var SetupDependencies = sync.OnceFunc(func() {
|
||||
databases.
|
||||
GetDatabaseService().
|
||||
AddDbCreationListener(healthcheckConfigService)
|
||||
})
|
||||
|
||||
@@ -2,7 +2,6 @@ package notifiers
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
@@ -39,21 +38,6 @@ func GetNotifierRepository() *NotifierRepository {
|
||||
return notifierRepository
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
var SetupDependencies = sync.OnceFunc(func() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
|
||||
})
|
||||
|
||||
@@ -23,12 +23,12 @@ type Notifier struct {
|
||||
LastSendError *string `json:"lastSendError" gorm:"column:last_send_error;type:text"`
|
||||
|
||||
// specific notifier
|
||||
TelegramNotifier *telegram_notifier.TelegramNotifier `json:"telegramNotifier" gorm:"foreignKey:NotifierID"`
|
||||
EmailNotifier *email_notifier.EmailNotifier `json:"emailNotifier" gorm:"foreignKey:NotifierID"`
|
||||
WebhookNotifier *webhook_notifier.WebhookNotifier `json:"webhookNotifier" gorm:"foreignKey:NotifierID"`
|
||||
SlackNotifier *slack_notifier.SlackNotifier `json:"slackNotifier" gorm:"foreignKey:NotifierID"`
|
||||
DiscordNotifier *discord_notifier.DiscordNotifier `json:"discordNotifier" gorm:"foreignKey:NotifierID"`
|
||||
TeamsNotifier *teams_notifier.TeamsNotifier `json:"teamsNotifier,omitempty" gorm:"foreignKey:NotifierID;constraint:OnDelete:CASCADE"`
|
||||
TelegramNotifier *telegram_notifier.TelegramNotifier `json:"telegramNotifier" gorm:"foreignKey:NotifierID"`
|
||||
EmailNotifier *email_notifier.EmailNotifier `json:"emailNotifier" gorm:"foreignKey:NotifierID"`
|
||||
WebhookNotifier *webhook_notifier.WebhookNotifier `json:"webhookNotifier" gorm:"foreignKey:NotifierID"`
|
||||
SlackNotifier *slack_notifier.SlackNotifier `json:"slackNotifier" gorm:"foreignKey:NotifierID"`
|
||||
DiscordNotifier *discord_notifier.DiscordNotifier `json:"discordNotifier" gorm:"foreignKey:NotifierID"`
|
||||
TeamsNotifier *teams_notifier.TeamsNotifier `json:"teamsNotifier,omitzero" gorm:"foreignKey:NotifierID;constraint:OnDelete:CASCADE"`
|
||||
}
|
||||
|
||||
func (n *Notifier) TableName() string {
|
||||
|
||||
@@ -49,7 +49,7 @@ type cardAttachment struct {
|
||||
type payload struct {
|
||||
Title string `json:"title"`
|
||||
Text string `json:"text"`
|
||||
Attachments []cardAttachment `json:"attachments,omitempty"`
|
||||
Attachments []cardAttachment `json:"attachments,omitzero"`
|
||||
}
|
||||
|
||||
func (n *TeamsNotifier) Send(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user