Compare commits

...

92 Commits

Author SHA1 Message Date
Rostislav Dugin
c2ddbfc86f Merge pull request #484 from databasus/develop
FEATURE (docker): Add PUID/PGID environment variables to control post…
2026-03-31 11:52:23 +03:00
Rostislav Dugin
f287967b5d FEATURE (docker): Add PUID/PGID environment variables to control postgres user UID/GID for host-level backup compatibility 2026-03-31 11:51:57 +03:00
Rostislav Dugin
ef879df08f Merge pull request #483 from databasus/develop
FIX (backups): Use system's temp directory instead of mounter directo…
2026-03-31 11:41:26 +03:00
Rostislav Dugin
44ddcb836e FIX (backups): Use system's temp directory instead of mounter directory to fix access permissions on TrueNAS 2026-03-31 11:40:11 +03:00
Rostislav Dugin
b5178f5752 Merge pull request #482 from databasus/develop
FEATURE (clipboard): Add parsing from clipboard via dialog in HTTP\no…
2026-03-31 11:21:25 +03:00
Rostislav Dugin
7913c1b474 FEATURE (clipboard): Add parsing from clipboard via dialog in HTTP\no navigator mode 2026-03-31 11:20:13 +03:00
Rostislav Dugin
2815cc3752 Merge pull request #481 from databasus/develop
FIX (storages): Validat only single rclone storage is passed
2026-03-31 10:37:54 +03:00
Rostislav Dugin
189573fa1b FIX (storages): Validat only single rclone storage is passed 2026-03-31 10:37:13 +03:00
Rostislav Dugin
81f77760c9 Merge pull request #479 from databasus/develop
FEATURE (navbar): Update navbar link color
2026-03-30 13:16:07 +03:00
Rostislav Dugin
63e23b2489 FEATURE (navbar): Update navbar link color 2026-03-30 13:15:03 +03:00
Rostislav Dugin
8c1b8ac00f Merge pull request #477 from databasus/develop
Develop
2026-03-29 15:45:15 +03:00
Rostislav Dugin
1926096377 FEATURE (backups): Add filters to backups panel 2026-03-29 15:33:01 +03:00
Rostislav Dugin
0a131511a8 FIX (agent): Fix uploading WAL to storages 2026-03-29 14:35:46 +03:00
Rostislav Dugin
aa01ce0b76 FEATURE (agent): Make installation guide more structured 2026-03-29 14:34:33 +03:00
Rostislav Dugin
496fc05993 Merge pull request #474 from databasus/develop
FIX (wal): Fix timeout upload test
2026-03-29 11:30:01 +03:00
Rostislav Dugin
1ac0eb4d5b FIX (wal): Fix timeout upload test 2026-03-29 11:29:44 +03:00
Rostislav Dugin
6b052902f7 Merge pull request #473 from databasus/develop
Develop
2026-03-28 22:53:13 +03:00
Rostislav Dugin
c7d091fe51 FEATURE (agent): Stop WAL and FULL backups on staling within 5 mins 2026-03-28 22:52:46 +03:00
Rostislav Dugin
b1dfd1c425 FIX (agent): Do not show cancel button for agent backups 2026-03-28 22:07:45 +03:00
Rostislav Dugin
4bee78646a REFACTOR (go): Refactor go to follow modern syntax guidelines 2026-03-28 22:02:22 +03:00
Rostislav Dugin
927eeabc0f Merge pull request #470 from databasus/develop
Develop
2026-03-27 23:58:39 +03:00
Rostislav Dugin
3a5a53c92d FIX (backups): Fix light theme for banner 2026-03-27 23:44:00 +03:00
Rostislav Dugin
f0ab470a84 FEATURE (readme): Update readme 2026-03-27 23:00:25 +03:00
Rostislav Dugin
f728fda759 FIX (backups): Hide hourly GFS when daily and more period selected 2026-03-27 22:51:40 +03:00
Rostislav Dugin
80b5df6283 FIX (billing): Fix UI units 2026-03-27 22:50:08 +03:00
Rostislav Dugin
67556a0db1 FIX (dockerfile): Fix index.html 2026-03-27 22:33:20 +03:00
Rostislav Dugin
c4cf7f8446 Merge branch 'develop' of https://github.com/databasus/databasus into develop 2026-03-27 22:02:40 +03:00
Rostislav Dugin
61a0bcabb1 FEATURE (cloud): Add cloud 2026-03-27 22:02:25 +03:00
Rostislav Dugin
f7f70a13eb Merge pull request #465 from databasus/develop
Develop
2026-03-24 18:01:35 +03:00
Rostislav Dugin
f1e289c421 Merge pull request #464 from kvendingoldo/main
feat: add support of imagePullSecrets
2026-03-24 18:00:15 +03:00
Alexander Sharov
c0952e057f feat: add support of imagePullSecrets
Signed-off-by: Alexander Sharov <kvendingoldo@yandex.ru>
2026-03-24 18:53:10 +04:00
Rostislav Dugin
b4d4e0a1d7 Merge pull request #459 from databasus/develop
FIX (readme): Fix typo in readme
2026-03-22 13:42:39 +03:00
Rostislav Dugin
c648e9c29f FIX (readme): Fix typo in readme 2026-03-22 13:41:56 +03:00
Rostislav Dugin
3fce6d2a99 Merge pull request #458 from databasus/develop
FIX (playground): Move index.html for playground into <noscript>
2026-03-22 09:58:48 +03:00
Rostislav Dugin
198b94ba9d FIX (playground): Move index.html for playground into <noscript> 2026-03-22 09:57:24 +03:00
Rostislav Dugin
80cd0bf5d3 Merge pull request #457 from databasus/develop
FIX (playground): Make turnstile mandatory in sign in and sign up
2026-03-21 22:32:26 +03:00
Rostislav Dugin
231e3cc709 FIX (playground): Make turnstile mandatory in sign in and sign up 2026-03-21 22:30:16 +03:00
Rostislav Dugin
8cf0fdacb1 Merge pull request #456 from databasus/develop
Develop
2026-03-21 14:15:21 +03:00
Rostislav Dugin
2d28af19dc FEATURE (playground): Remove playground warning 2026-03-21 14:14:22 +03:00
Rostislav Dugin
67dc257fda FIX (mariadb\mysql): Skip SSL if https mode is set to false 2026-03-21 14:12:53 +03:00
Rostislav Dugin
881167f812 FEATURE (index.html): Adjust policices for playgronund index 2026-03-21 14:08:09 +03:00
Rostislav Dugin
cf807cfc54 FIX (mariadb\mysql): Skip tables locking over restores 2026-03-21 12:55:34 +03:00
Rostislav Dugin
df91651709 Merge pull request #455 from databasus/develop
FIX (readme): Fix FAQ link
2026-03-21 12:44:22 +03:00
Rostislav Dugin
b0592dae9e FIX (readme): Fix FAQ link 2026-03-21 12:43:55 +03:00
Rostislav Dugin
c997202484 Merge pull request #454 from databasus/develop
FEATURE (notifiers): Change testing notifier from Telegram to webhook
2026-03-21 12:27:24 +03:00
Rostislav Dugin
a17ea2f3e2 FEATURE (notifiers): Change testing notifier from Telegram to webhook 2026-03-21 12:26:57 +03:00
Rostislav Dugin
856aa1c256 Merge pull request #453 from databasus/develop
FIX (agent): Make E2E test for locking check more stable
2026-03-21 11:59:35 +03:00
Rostislav Dugin
f60f677351 FIX (agent): Make E2E test for locking check more stable 2026-03-21 11:57:09 +03:00
Rostislav Dugin
4c980746ab Merge pull request #452 from databasus/develop
Develop
2026-03-20 17:43:41 +03:00
Rostislav Dugin
89197bbbc6 FEATURE (restore): Add restore hint for Docker 2026-03-20 17:41:09 +03:00
Rostislav Dugin
e2ac5bfbd7 FIX (agent): Make pgType param mandatory over restore 2026-03-20 17:29:56 +03:00
Rostislav Dugin
cf6e8f212a FIX (agent): Adjust restore path for Docker PG restoration 2026-03-20 17:04:46 +03:00
Rostislav Dugin
6ee7e02f5d FEATURE (agent): Change recovery target dir flag name 2026-03-20 14:07:12 +03:00
Rostislav Dugin
14bcd3d70b FEATURE (readme): Update readme 2026-03-20 13:55:12 +03:00
Rostislav Dugin
5faa11f82a FEATURE (agent): Increase agent update check interval 2026-03-20 13:54:56 +03:00
Rostislav Dugin
2c4e3e567b FEATURE (agent): Extend WAL logging 2026-03-20 13:38:11 +03:00
Rostislav Dugin
82d615545b FIX (agent): Verify PostgreSQL connection without requirement to expose ports 2026-03-20 12:45:02 +03:00
Rostislav Dugin
e913f4c32e FIX (e2e): Fix inclusion of e2e to Makefile in mandatory way 2026-03-20 11:52:49 +03:00
Rostislav Dugin
57a75918e4 FEATURE (ci \ cd): Add publishing dev image 2026-03-20 11:46:53 +03:00
Rostislav Dugin
8a601c7f68 FEATURE (agent): Add restore from WAL-backup 2026-03-19 23:35:54 +03:00
Rostislav Dugin
f0064b4be3 Merge pull request #448 from databasus/develop
FIX (agent): Fix lock test
2026-03-17 16:41:47 +03:00
Rostislav Dugin
94505bab3f FIX (agent): Fix lock test 2026-03-17 16:41:07 +03:00
Rostislav Dugin
9acf3cff09 Merge pull request #447 from databasus/develop
Develop
2026-03-17 16:36:45 +03:00
Rostislav Dugin
0d7e147df6 FIX (wal): Allow to save error via /complete endpoint 2026-03-17 16:33:00 +03:00
Rostislav Dugin
1394b47570 FIX (agent): Fix linting issues 2026-03-17 14:55:16 +03:00
Rostislav Dugin
a9865ae3e4 Merge pull request #446 from databasus/develop
Develop
2026-03-17 14:39:24 +03:00
Rostislav Dugin
4b5478e60a FEATURE (upgrader): Add background upgrading of the agent 2026-03-17 14:38:32 +03:00
Rostislav Dugin
6355301903 FIX (agent): Respect API responses status code when retying 2026-03-16 22:13:47 +03:00
Rostislav Dugin
29b403a9c6 FIX (wal): Enforce streaming without RAM buffering over base backup 2026-03-16 21:53:40 +03:00
Rostislav Dugin
12606053f4 FEATURE (params): Rename WAL dir param 2026-03-16 17:50:09 +03:00
Rostislav Dugin
904b386378 FIX (logger): Limit logger to 5Mb 2026-03-16 17:31:37 +03:00
Rostislav Dugin
1d9738b808 FEATURE (agent): Make zstd compression 5 by default 2026-03-16 15:52:37 +03:00
Databasus
58b37f4c92 Merge pull request #443 from gogo199432/main
feat(helm): add service annotations support
2026-03-16 15:47:26 +03:00
gordon
6c4f814c94 feat(helm): add service annotations support 2026-03-15 16:45:51 +01:00
Rostislav Dugin
bcd13c27d3 FIX (agent): Add lock file watcher to exit from process in case of lock file deletion 2026-03-15 18:04:03 +03:00
Rostislav Dugin
120f9600bf FEATURE (agent): Add check for PG >= 15 for WAL 2026-03-15 17:48:13 +03:00
Rostislav Dugin
563c7c1d64 FEATURE (agent): Add running as daemon 2026-03-15 17:37:13 +03:00
Rostislav Dugin
68f15f7661 FEATURE (agent): Add WAL streaming 2026-03-15 14:04:54 +03:00
Rostislav Dugin
627d96a00d FIX (backups): Do not validate chain on WAL uploading 2026-03-15 13:13:42 +03:00
Rostislav Dugin
02b9a9ec8d FEATURE (agent): Add locking to ensure single running instance 2026-03-14 13:55:57 +03:00
Rostislav Dugin
415dda8752 Merge pull request #440 from databasus/develop
FIX (local storage): Add fallback for file movement via renaming to s…
2026-03-14 13:38:39 +03:00
Rostislav Dugin
3faf85796a FIX (local storage): Add fallback for file movement via renaming to support cross-device movement 2026-03-14 13:32:29 +03:00
Rostislav Dugin
edd2759f5a Merge pull request #439 from databasus/develop
FIX (ci \ cd): Add e2e agent docker-compose to repo
2026-03-14 13:17:03 +03:00
Rostislav Dugin
c283856f38 FIX (ci \ cd): Add e2e agent docker-compose to repo 2026-03-14 13:15:34 +03:00
Rostislav Dugin
6059e1a33b Merge pull request #438 from databasus/develop
FIX (ci \ cd): Exclude agent e2e from docker ignore
2026-03-14 13:11:53 +03:00
Rostislav Dugin
2deda2e7ea FIX (ci \ cd): Exclude agent e2e from docker ignore 2026-03-14 13:11:27 +03:00
Rostislav Dugin
acf1143752 Merge pull request #437 from databasus/develop
FIX (ci \ cd): Update e2e tests for agent to run on GitHub workers
2026-03-14 12:54:56 +03:00
Rostislav Dugin
889063a8b4 FIX (ci \ cd): Update e2e tests for agent to run on GitHub workers 2026-03-14 12:54:32 +03:00
Rostislav Dugin
a1e20e7b10 Merge pull request #436 from databasus/develop
FIX (linting): Add E2E to linting
2026-03-14 12:48:23 +03:00
Rostislav Dugin
7e76945550 FIX (linting): Add E2E to linting 2026-03-14 12:47:43 +03:00
Rostislav Dugin
d98acfc4af Merge pull request #435 from databasus/develop
FEATURE (agent): Add postgres verification and e2e tests for agent
2026-03-14 12:43:51 +03:00
Rostislav Dugin
0ffc7c8c96 FEATURE (agent): Add postgres verification and e2e tests for agent 2026-03-14 12:43:13 +03:00
263 changed files with 20992 additions and 4187 deletions

View File

@@ -9,6 +9,7 @@ on:
jobs:
lint-backend:
if: github.ref != 'refs/heads/develop'
runs-on: self-hosted
container:
image: golang:1.26.1
@@ -55,6 +56,7 @@ jobs:
git diff --exit-code go.mod go.sum || (echo "go mod tidy made changes, please run 'go mod tidy' and commit the changes" && exit 1)
lint-frontend:
if: github.ref != 'refs/heads/develop'
runs-on: ubuntu-latest
steps:
- name: Check out code
@@ -87,6 +89,7 @@ jobs:
npm run build
lint-agent:
if: github.ref != 'refs/heads/develop'
runs-on: ubuntu-latest
steps:
- name: Check out code
@@ -120,6 +123,7 @@ jobs:
git diff --exit-code go.mod go.sum || (echo "go mod tidy made changes, please run 'go mod tidy' and commit the changes" && exit 1)
test-frontend:
if: github.ref != 'refs/heads/develop'
runs-on: ubuntu-latest
needs: [lint-frontend]
steps:
@@ -142,6 +146,7 @@ jobs:
npm run test
test-agent:
if: github.ref != 'refs/heads/develop'
runs-on: ubuntu-latest
needs: [lint-agent]
steps:
@@ -164,7 +169,54 @@ jobs:
cd agent
go test -count=1 -failfast ./internal/...
e2e-agent:
if: github.ref != 'refs/heads/develop'
runs-on: ubuntu-latest
needs: [lint-agent]
steps:
- name: Check out code
uses: actions/checkout@v4
- name: Run e2e tests
run: |
cd agent
make e2e
- name: Cleanup
if: always()
run: |
cd agent/e2e
docker compose down -v --rmi local || true
rm -rf artifacts || true
e2e-agent-backup-restore:
if: github.ref != 'refs/heads/develop'
runs-on: ubuntu-latest
needs: [lint-agent]
strategy:
matrix:
pg_version: [15, 16, 17, 18]
fail-fast: false
steps:
- name: Check out code
uses: actions/checkout@v4
- name: Run backup-restore e2e (PG ${{ matrix.pg_version }})
run: |
cd agent
make e2e-backup-restore PG_VERSION=${{ matrix.pg_version }}
- name: Cleanup
if: always()
run: |
cd agent/e2e
docker compose -f docker-compose.backup-restore.yml down -v --rmi local || true
rm -rf artifacts || true
# Self-hosted: performant high-frequency CPU is used to start many containers and run tests fast. Tests
# step is bottle-neck, because we need a lot of containers and cannot parallelize tests due to shared resources
test-backend:
if: github.ref != 'refs/heads/develop'
runs-on: self-hosted
needs: [lint-backend]
container:
@@ -252,9 +304,6 @@ jobs:
TEST_MARIADB_114_PORT=33114
TEST_MARIADB_118_PORT=33118
TEST_MARIADB_120_PORT=33120
# testing Telegram
TEST_TELEGRAM_BOT_TOKEN=${{ secrets.TEST_TELEGRAM_BOT_TOKEN }}
TEST_TELEGRAM_CHAT_ID=${{ secrets.TEST_TELEGRAM_CHAT_ID }}
# supabase
TEST_SUPABASE_HOST=${{ secrets.TEST_SUPABASE_HOST }}
TEST_SUPABASE_PORT=${{ secrets.TEST_SUPABASE_PORT }}
@@ -493,11 +542,47 @@ jobs:
echo "Cleanup complete"
build-and-push-dev:
runs-on: self-hosted
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/develop' }}
steps:
- name: Clean workspace
run: |
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
- name: Check out code
uses: actions/checkout@v4
- name: Set up QEMU (enables multi-arch emulation)
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Build and push dev image
uses: docker/build-push-action@v5
with:
context: .
push: true
platforms: linux/amd64,linux/arm64
build-args: |
APP_VERSION=dev-${{ github.sha }}
tags: |
databasus/databasus-dev:latest
databasus/databasus-dev:${{ github.sha }}
determine-version:
runs-on: self-hosted
container:
image: node:20
needs: [test-backend, test-frontend, test-agent]
needs: [test-backend, test-frontend, test-agent, e2e-agent, e2e-agent-backup-restore]
if: ${{ github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, '[skip-release]') }}
outputs:
should_release: ${{ steps.version_bump.outputs.should_release }}
@@ -588,43 +673,6 @@ jobs:
echo "No version bump needed"
fi
build-only:
runs-on: self-hosted
needs: [test-backend, test-frontend, test-agent]
if: ${{ github.ref == 'refs/heads/main' && contains(github.event.head_commit.message, '[skip-release]') }}
steps:
- name: Clean workspace
run: |
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
- name: Check out code
uses: actions/checkout@v4
- name: Set up QEMU (enables multi-arch emulation)
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Build and push SHA-only tags
uses: docker/build-push-action@v5
with:
context: .
push: true
platforms: linux/amd64,linux/arm64
build-args: |
APP_VERSION=dev-${{ github.sha }}
tags: |
databasus/databasus:latest
databasus/databasus:${{ github.sha }}
build-and-push:
runs-on: self-hosted
needs: [determine-version]

.gitignore vendored (1 line changed)
View File

@@ -5,6 +5,7 @@ databasus-data/
.env
pgdata/
docker-compose.yml
!agent/e2e/docker-compose.yml
node_modules/
.idea
/articles

AGENTS.md (836 lines changed)

File diff suppressed because it is too large.

View File

@@ -239,7 +239,8 @@ RUN apt-get update && \
fi
# Create postgres user and set up directories
RUN useradd -m -s /bin/bash postgres || true && \
RUN groupadd -g 999 postgres || true && \
useradd -m -s /bin/bash -u 999 -g 999 postgres || true && \
mkdir -p /databasus-data/pgdata && \
chown -R postgres:postgres /databasus-data/pgdata
@@ -257,6 +258,9 @@ COPY backend/migrations ./migrations
# Copy UI files
COPY --from=backend-build /app/ui/build ./ui/build
# Copy cloud static HTML template (injected into index.html at startup when IS_CLOUD=true)
COPY frontend/cloud-root-content.html /app/cloud-root-content.html
# Copy agent binaries (both architectures) — served by the backend
# at GET /api/v1/system/agent?arch=amd64|arm64
COPY --from=agent-build /agent-binaries ./agent-binaries
@@ -291,6 +295,23 @@ if [ -d "/postgresus-data" ] && [ "\$(ls -A /postgresus-data 2>/dev/null)" ]; th
exit 1
fi
# ========= Adjust postgres user UID/GID =========
PUID=\${PUID:-999}
PGID=\${PGID:-999}
CURRENT_UID=\$(id -u postgres)
CURRENT_GID=\$(id -g postgres)
if [ "\$CURRENT_GID" != "\$PGID" ]; then
echo "Adjusting postgres group GID from \$CURRENT_GID to \$PGID..."
groupmod -o -g "\$PGID" postgres
fi
if [ "\$CURRENT_UID" != "\$PUID" ]; then
echo "Adjusting postgres user UID from \$CURRENT_UID to \$PUID..."
usermod -o -u "\$PUID" postgres
fi
# PostgreSQL 17 binary paths
PG_BIN="/usr/lib/postgresql/17/bin"
@@ -313,7 +334,9 @@ window.__RUNTIME_CONFIG__ = {
GOOGLE_CLIENT_ID: '\${GOOGLE_CLIENT_ID:-}',
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED',
CLOUDFLARE_TURNSTILE_SITE_KEY: '\${CLOUDFLARE_TURNSTILE_SITE_KEY:-}',
CONTAINER_ARCH: '\${CONTAINER_ARCH:-unknown}'
CONTAINER_ARCH: '\${CONTAINER_ARCH:-unknown}',
CLOUD_PRICE_PER_GB: '\${CLOUD_PRICE_PER_GB:-}',
CLOUD_PADDLE_CLIENT_TOKEN: '\${CLOUD_PADDLE_CLIENT_TOKEN:-}'
};
JSEOF
@@ -326,6 +349,32 @@ if [ -n "\${ANALYTICS_SCRIPT:-}" ]; then
fi
fi
# Inject Paddle script if client token is provided (only if not already injected)
if [ -n "\${CLOUD_PADDLE_CLIENT_TOKEN:-}" ]; then
if ! grep -q "cdn.paddle.com" /app/ui/build/index.html 2>/dev/null; then
echo "Injecting Paddle script..."
sed -i "s#</head># <script src=\"https://cdn.paddle.com/paddle/v2/paddle.js\"></script>\\
</head>#" /app/ui/build/index.html
fi
fi
# Inject static HTML into root div for cloud mode (payment system requires visible legal links)
if [ "\${IS_CLOUD:-false}" = "true" ]; then
if ! grep -q "cloud-static-content" /app/ui/build/index.html 2>/dev/null; then
echo "Injecting cloud static HTML content..."
perl -i -pe '
BEGIN {
open my \$fh, "<", "/app/cloud-root-content.html" or die;
local \$/;
\$c = <\$fh>;
close \$fh;
\$c =~ s/\\n/ /g;
}
s/<div id="root"><\\/div>/<div id="root"><!-- cloud-static-content --><noscript>\$c<\\/noscript><\\/div>/
' /app/ui/build/index.html
fi
fi
# Ensure proper ownership of data directory
echo "Setting up data directory permissions..."
mkdir -p /databasus-data/pgdata
@@ -439,7 +488,7 @@ fi
echo "Setting up database and user..."
gosu postgres \$PG_BIN/psql -p 5437 -h localhost -d postgres << 'SQL'
# We use stub password, because internal DB is not exposed outside container
-- We use stub password, because internal DB is not exposed outside container
ALTER USER postgres WITH PASSWORD 'Q1234567';
SELECT 'CREATE DATABASE databasus OWNER postgres'
WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = 'databasus')
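The PUID/PGID block added to the entrypoint above remaps the in-container postgres user before the data directory ownership is set up, so a host-mounted /databasus-data can be owned by an ordinary host user (useful on TrueNAS and similar hosts). A minimal usage sketch, assuming the image tag from the workflow and a host user with UID/GID 1000 (port mapping and other settings omitted):

docker run -d \
  -e PUID=1000 \
  -e PGID=1000 \
  -v /srv/databasus-data:/databasus-data \
  databasus/databasus:latest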

NOTICE.md (new file, 22 lines)
View File

@@ -0,0 +1,22 @@
Copyright © 2025-2026 Rostislav Dugin and contributors.
“Databasus” is a trademark of Rostislav Dugin.
The source code in this repository is licensed under the Apache License, Version 2.0.
That license applies to the code only and does not grant any right to use the
Databasus name, logo, or branding, except for reasonable and customary referential
use in describing the origin of the software and reproducing the content of this NOTICE.
Permitted referential use includes truthful use of the name “Databasus” to identify
the original Databasus project in software catalogs, deployment templates, hosting
panels, package indexes, compatibility pages, integrations, tutorials, reviews, and
similar informational materials, including phrases such as “Databasus”,
“Deploy Databasus”, “Databasus on Coolify”, and “Compatible with Databasus”.
You may not use “Databasus” as the name or primary branding of a competing product,
service, fork, distribution, or hosted offering, or in any manner likely to cause
confusion as to source, affiliation, sponsorship, or endorsement.
Nothing in this repository transfers, waives, limits, or estops any rights in the
Databasus mark. All trademark rights are reserved except for the limited referential
use stated above.

View File

@@ -1,8 +1,8 @@
<div align="center">
<img src="assets/logo.svg" alt="Databasus Logo" width="250"/>
<h3>Backup tool for PostgreSQL, MySQL and MongoDB</h3>
<p>Databasus is a free, open source and self-hosted tool to backup databases (with focus on PostgreSQL). Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.)</p>
<h3>PostgreSQL backup tool (with MySQL\MariaDB and MongoDB support)</h3>
<p>Databasus is a free, open source and self-hosted tool to backup databases (with primary focus on PostgreSQL). Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.)</p>
<!-- Badges -->
[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-336791?logo=postgresql&logoColor=white)](https://www.postgresql.org/)
@@ -91,14 +91,16 @@ It is also important for Databasus that you are able to decrypt and restore back
- **Dark & light themes**: Choose the look that suits your workflow
- **Mobile adaptive**: Check your backups from anywhere on any device
### ☁️ **Works with self-hosted & cloud databases**
### 🔌 **Connection types**
Databasus works seamlessly with both self-hosted PostgreSQL and cloud-managed databases:
- **Remote** — Databasus connects directly to the database over the network (recommended in read-only mode). No agent or additional software required. Works with cloud-managed and self-hosted databases
- **Agent** — A lightweight Databasus agent (written in Go) runs alongside the database. The agent streams backups directly to Databasus, so the database never needs to be exposed publicly. Supports host-installed databases and Docker containers
- **Cloud support**: AWS RDS, Google Cloud SQL, Azure Database for PostgreSQL
- **Self-hosted**: Any PostgreSQL instance you manage yourself
- **Why no PITR support?**: Cloud providers already offer native PITR, and external PITR backups cannot be restored to managed cloud databases — making them impractical for cloud-hosted PostgreSQL
- **Practical granularity**: Hourly and daily backups are sufficient for 99% of projects without the operational complexity of WAL archiving
### 📦 **Backup types**
- **Logical** — Native dump of the database in its engine-specific binary format. Compressed and streamed directly to storage with no intermediate files
- **Physical** — File-level copy of the entire database cluster. Faster backup and restore for large datasets compared to logical dumps
- **Incremental** — Physical base backup combined with continuous WAL segment archiving. Enables Point-in-time recovery (PITR) — restore to any second between backups. Designed for disaster recovery and near-zero data loss requirements
### 🐳 **Self-hosted & secure**
@@ -243,7 +245,7 @@ Replace `admin` with the actual email address of the user whose password you wan
### 💾 Backuping Databasus itself
After installation, it is also recommended to <a href="https://databasus.com/faq/#backup-databasus">backup your Databasus itself</a> or, at least, to copy secret key used for encryption (30 seconds is needed). So you are able to restore from your encrypted backups if you lose access to the server with Databasus or it is corrupted.
After installation, it is also recommended to <a href="https://databasus.com/faq#backup-databasus">backup your Databasus itself</a> or, at least, to copy secret key used for encryption (30 seconds is needed). So you are able to restore from your encrypted backups if you lose access to the server with Databasus or it is corrupted.
---
@@ -257,7 +259,9 @@ Contributions are welcome! Read the <a href="https://databasus.com/contribute">c
Also you can join our large community of developers, DBAs and DevOps engineers on Telegram [@databasus_community](https://t.me/databasus_community).
## AI disclaimer
## FAQ
### AI disclaimer
There have been questions about AI usage in project development in issues and discussions. As the project focuses on security, reliability and production usage, it's important to explain how AI is used in the development process.
@@ -293,3 +297,13 @@ Moreover, it's important to note that we do not differentiate between bad human
Even if code is written manually by a human, it's not guaranteed to be merged. Vibe code is not allowed at all and all such PRs are rejected by default (see [contributing guide](https://databasus.com/contribute)).
We also draw attention to fast issue resolution and security [vulnerability reporting](https://github.com/databasus/databasus?tab=security-ov-file#readme).
### You have a cloud version — are you truly open source?
Yes. Every feature available in Databasus Cloud is equally available in the self-hosted version with no restrictions, no feature gates and no usage limits. The entire codebase is Apache 2.0 licensed and always will be.
Databasus is not "open core." We do not withhold features behind a paid tier and then call the limited remainder "open source," as projects like GitLab or Sentry do. We believe open source means the complete product is open, not just a marketing label on a stripped-down edition.
Databasus Cloud runs the exact same code as the self-hosted version. The only difference is that we take care of infrastructure, availability, backups, reservations, monitoring and updates for you — so you don't have to. If you are using cloud, you can always move your databases from cloud to self-hosted if you wish.
Revenue from Cloud funds full-time development of the project. Most large open-source projects rely on corporate backing or sponsorship to survive. We chose a different path: Databasus sustains itself so it can grow and improve independently, without being tied to any enterprise or sponsor.

View File

@@ -1 +1,3 @@
ENV_MODE=development
AGENT_DB_ID=your-database-id
AGENT_TOKEN=your-agent-token

agent/.gitignore vendored (6 lines changed)
View File

@@ -1,6 +1,7 @@
main
.env
docker-compose.yml
!e2e/docker-compose.yml
pgdata
pgdata_test/
mysqldata/
@@ -20,4 +21,7 @@ cmd.exe
temp/
valkey-data/
victoria-logs-data/
databasus.json
databasus.json
.test-tmp/
databasus.log
wal-queue/

View File

@@ -1,6 +1,21 @@
# Usage: make run ARGS="start --pg-host localhost"
.PHONY: run build test lint e2e e2e-clean e2e-backup-restore e2e-backup-restore-clean
-include .env
export
run:
go run cmd/main.go $(ARGS)
go run cmd/main.go start \
--databasus-host http://localhost:4005 \
--db-id $(AGENT_DB_ID) \
--token $(AGENT_TOKEN) \
--pg-host 127.0.0.1 \
--pg-port 7433 \
--pg-user devuser \
--pg-password devpassword \
--pg-type docker \
--pg-docker-container-name dev-postgres \
--pg-wal-dir ./wal-queue \
--skip-update
build:
CGO_ENABLED=0 go build -ldflags "-X main.Version=$(VERSION)" -o databasus-agent ./cmd/main.go
@@ -9,4 +24,18 @@ test:
go test -count=1 -failfast ./internal/...
lint:
golangci-lint fmt ./cmd/... ./internal/... && golangci-lint run ./cmd/... ./internal/...
golangci-lint fmt ./cmd/... ./internal/... ./e2e/... && golangci-lint run ./cmd/... ./internal/... ./e2e/...
e2e:
cd e2e && docker compose build --no-cache e2e-mock-server
cd e2e && docker compose build
cd e2e && docker compose run --rm e2e-agent-builder
cd e2e && docker compose up -d e2e-postgres e2e-mock-server
cd e2e && docker compose run --rm e2e-agent-runner
cd e2e && docker compose run --rm e2e-agent-docker
cd e2e && docker compose down -v
e2e-clean:
cd e2e && docker compose down -v --rmi local
cd e2e && docker compose -f docker-compose.backup-restore.yml down -v --rmi local 2>/dev/null || true
rm -rf e2e/artifacts

View File

@@ -1,13 +1,19 @@
package main
import (
"context"
"errors"
"flag"
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
"syscall"
"databasus-agent/internal/config"
"databasus-agent/internal/features/api"
"databasus-agent/internal/features/restore"
"databasus-agent/internal/features/start"
"databasus-agent/internal/features/upgrade"
"databasus-agent/internal/logger"
@@ -24,6 +30,8 @@ func main() {
switch os.Args[1] {
case "start":
runStart(os.Args[2:])
case "_run":
runDaemon(os.Args[2:])
case "stop":
runStop()
case "status":
@@ -42,7 +50,6 @@ func main() {
func runStart(args []string) {
fs := flag.NewFlagSet("start", flag.ExitOnError)
isDebug := fs.Bool("debug", false, "Enable debug logging")
isSkipUpdate := fs.Bool("skip-update", false, "Skip auto-update check")
cfg := &config.Config{}
@@ -52,36 +59,67 @@ func runStart(args []string) {
fmt.Fprintf(os.Stderr, "Failed to save config: %v\n", err)
}
logger.Init(*isDebug)
log := logger.GetLogger()
isDev := checkIsDevelopment()
runUpdateCheck(cfg.DatabasusHost, *isSkipUpdate, isDev, log)
if err := start.Run(cfg, log); err != nil {
if err := start.Start(cfg, Version, isDev, log); err != nil {
if errors.Is(err, upgrade.ErrUpgradeRestart) {
reexecAfterUpgrade(log)
}
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
}
func runDaemon(args []string) {
fs := flag.NewFlagSet("_run", flag.ExitOnError)
if err := fs.Parse(args); err != nil {
os.Exit(1)
}
log := logger.GetLogger()
cfg := &config.Config{}
cfg.LoadFromJSON()
if err := start.RunDaemon(cfg, Version, checkIsDevelopment(), log); err != nil {
if errors.Is(err, upgrade.ErrUpgradeRestart) {
reexecAfterUpgrade(log)
}
log.Error("Agent exited with error", "error", err)
os.Exit(1)
}
}
func runStop() {
logger.Init(false)
logger.GetLogger().Info("stop: stub — not yet implemented")
log := logger.GetLogger()
if err := start.Stop(log); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
}
func runStatus() {
logger.Init(false)
logger.GetLogger().Info("status: stub — not yet implemented")
log := logger.GetLogger()
if err := start.Status(log); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
}
func runRestore(args []string) {
fs := flag.NewFlagSet("restore", flag.ExitOnError)
targetDir := fs.String("target-dir", "", "Target pgdata directory")
pgDataDir := fs.String("target-dir", "", "Target pgdata directory (required)")
backupID := fs.String("backup-id", "", "Full backup UUID (optional)")
targetTime := fs.String("target-time", "", "PITR target time in RFC3339 (optional)")
isYes := fs.Bool("yes", false, "Skip confirmation prompt")
isDebug := fs.Bool("debug", false, "Enable debug logging")
isSkipUpdate := fs.Bool("skip-update", false, "Skip auto-update check")
cfg := &config.Config{}
@@ -91,18 +129,34 @@ func runRestore(args []string) {
fmt.Fprintf(os.Stderr, "Failed to save config: %v\n", err)
}
logger.Init(*isDebug)
log := logger.GetLogger()
isDev := checkIsDevelopment()
runUpdateCheck(cfg.DatabasusHost, *isSkipUpdate, isDev, log)
log.Info("restore: stub — not yet implemented",
"targetDir", *targetDir,
"backupId", *backupID,
"targetTime", *targetTime,
"yes", *isYes,
)
if *pgDataDir == "" {
fmt.Fprintln(os.Stderr, "Error: --target-dir is required")
os.Exit(1)
}
if cfg.DatabasusHost == "" || cfg.Token == "" {
fmt.Fprintln(os.Stderr, "Error: databasus-host and token must be configured")
os.Exit(1)
}
if cfg.PgType != "host" && cfg.PgType != "docker" {
fmt.Fprintf(os.Stderr, "Error: --pg-type must be 'host' or 'docker', got '%s'\n", cfg.PgType)
os.Exit(1)
}
apiClient := api.NewClient(cfg.DatabasusHost, cfg.Token, log)
restorer := restore.NewRestorer(apiClient, log, *pgDataDir, *backupID, *targetTime, cfg.PgType)
ctx := context.Background()
if err := restorer.Run(ctx); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
}
func printUsage() {
@@ -116,12 +170,7 @@ func printUsage() {
fmt.Fprintln(os.Stderr, " version Print agent version")
}
func runUpdateCheck(host string, isSkipUpdate, isDev bool, log interface {
Info(string, ...any)
Warn(string, ...any)
Error(string, ...any)
},
) {
func runUpdateCheck(host string, isSkipUpdate, isDev bool, log *slog.Logger) {
if isSkipUpdate {
return
}
@@ -130,10 +179,17 @@ func runUpdateCheck(host string, isSkipUpdate, isDev bool, log interface {
return
}
if err := upgrade.CheckAndUpdate(host, Version, isDev, log); err != nil {
apiClient := api.NewClient(host, "", log)
isUpgraded, err := upgrade.CheckAndUpdate(apiClient, Version, isDev, log)
if err != nil {
log.Error("Auto-update failed", "error", err)
os.Exit(1)
}
if isUpgraded {
reexecAfterUpgrade(log)
}
}
func checkIsDevelopment() bool {
@@ -172,3 +228,18 @@ func parseEnvMode(data []byte) bool {
return false
}
func reexecAfterUpgrade(log *slog.Logger) {
selfPath, err := os.Executable()
if err != nil {
log.Error("Failed to resolve executable for re-exec", "error", err)
os.Exit(1)
}
log.Info("Re-executing after upgrade...")
if err := syscall.Exec(selfPath, os.Args, os.Environ()); err != nil {
log.Error("Failed to re-exec after upgrade", "error", err)
os.Exit(1)
}
}
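With the restore stubs replaced above, the restore command requires --target-dir plus a configured databasus host, token and pg-type ('host' or 'docker'); the e2e helper script further below passes the host and token as flags. A minimal sketch using the binary name from the agent Makefile, placeholder values, and an assumed prior agent configuration for the PostgreSQL connection details:

./databasus-agent restore \
  --databasus-host https://databasus.example.com \
  --token <agent token> \
  --target-dir /var/lib/postgresql/17/restored \
  --yes \
  --skip-update
# optional narrowing via the flags defined above:
#   --backup-id <full backup UUID>
#   --target-time 2026-03-20T12:00:00Z   (PITR target, RFC3339)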

View File

@@ -0,0 +1,58 @@
services:
dev-postgres:
image: postgres:17
container_name: dev-postgres
environment:
POSTGRES_DB: devdb
POSTGRES_USER: devuser
POSTGRES_PASSWORD: devpassword
ports:
- "7433:5432"
command:
- bash
- -c
- |
mkdir -p /wal-queue && chown postgres:postgres /wal-queue
exec docker-entrypoint.sh postgres \
-c wal_level=replica \
-c max_wal_senders=3 \
-c archive_mode=on \
-c "archive_command=cp %p /wal-queue/%f"
volumes:
- ./wal-queue:/wal-queue
healthcheck:
test: ["CMD-SHELL", "pg_isready -U devuser -d devdb"]
interval: 2s
timeout: 5s
retries: 30
db-writer:
image: postgres:17
container_name: dev-db-writer
depends_on:
dev-postgres:
condition: service_healthy
environment:
PGHOST: dev-postgres
PGPORT: "5432"
PGUSER: devuser
PGPASSWORD: devpassword
PGDATABASE: devdb
command:
- bash
- -c
- |
echo "Waiting for postgres..."
until pg_isready -h dev-postgres -U devuser -d devdb; do sleep 1; done
psql -c "DROP TABLE IF EXISTS wal_generator;"
psql -c "CREATE TABLE wal_generator (id SERIAL PRIMARY KEY, data TEXT NOT NULL);"
echo "Starting WAL generation loop..."
while true; do
echo "Inserting ~50MB of data..."
psql -c "INSERT INTO wal_generator (data) SELECT repeat(md5(random()::text), 640) FROM generate_series(1, 2500);"
echo "Deleting data..."
psql -c "DELETE FROM wal_generator;"
echo "Cycle complete, sleeping 5s..."
sleep 5
done
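This development compose pairs with the updated Makefile run target shown earlier: dev-postgres listens on host port 7433 with archive_command copying WAL segments into ./wal-queue, while db-writer keeps generating WAL for the agent to pick up. A minimal local loop, assuming the compose file sits next to the agent Makefile (its exact path is not shown in this diff) and .env provides AGENT_DB_ID and AGENT_TOKEN:

docker compose up -d dev-postgres db-writer
make run          # agent connects to 127.0.0.1:7433 with the flags from the Makefile above
docker compose down -v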

agent/e2e/.gitignore vendored (new file, 2 lines)
View File

@@ -0,0 +1,2 @@
artifacts/
pgdata/

View File

@@ -0,0 +1,13 @@
# Builds agent binaries with different versions so
# we can test upgrade behavior (v1 -> v2)
FROM golang:1.26.1-alpine AS build
WORKDIR /src
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN CGO_ENABLED=0 go build -ldflags "-X main.Version=v1.0.0" -o /out/agent-v1 ./cmd/main.go
RUN CGO_ENABLED=0 go build -ldflags "-X main.Version=v2.0.0" -o /out/agent-v2 ./cmd/main.go
FROM alpine:3.21
COPY --from=build /out/ /out/
CMD ["cp", "-v", "/out/agent-v1", "/out/agent-v2", "/artifacts/"]

View File

@@ -0,0 +1,22 @@
# Runs backup-restore via docker exec test (test 6). Needs both Docker
# CLI (for pg_basebackup via docker exec) and PostgreSQL server (for
# restore verification).
FROM debian:bookworm-slim
RUN apt-get update && \
apt-get install -y --no-install-recommends \
ca-certificates curl gnupg2 locales postgresql-common && \
sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen && \
locale-gen && \
/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && \
apt-get install -y --no-install-recommends \
postgresql-17 && \
install -m 0755 -d /etc/apt/keyrings && \
curl -fsSL https://download.docker.com/linux/debian/gpg -o /etc/apt/keyrings/docker.asc && \
echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/debian bookworm stable" > /etc/apt/sources.list.d/docker.list && \
apt-get update && \
apt-get install -y --no-install-recommends docker-ce-cli && \
rm -rf /var/lib/apt/lists/*
WORKDIR /tmp
ENTRYPOINT []

View File

@@ -0,0 +1,14 @@
# Runs upgrade and host-mode backup-restore tests (tests 1-5). Needs
# full PostgreSQL server for backup-restore lifecycle tests.
FROM debian:bookworm-slim
RUN apt-get update && \
apt-get install -y --no-install-recommends \
ca-certificates curl gnupg2 postgresql-common && \
/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && \
apt-get install -y --no-install-recommends \
postgresql-17 && \
rm -rf /var/lib/apt/lists/*
WORKDIR /tmp
ENTRYPOINT []

View File

@@ -0,0 +1,16 @@
# Runs backup-restore lifecycle tests with a specific PostgreSQL version.
# Used for PG version matrix testing (15, 16, 17, 18).
FROM debian:bookworm-slim
ARG PG_VERSION=17
RUN apt-get update && \
apt-get install -y --no-install-recommends \
ca-certificates curl gnupg2 postgresql-common && \
/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && \
apt-get install -y --no-install-recommends \
postgresql-${PG_VERSION} && \
rm -rf /var/lib/apt/lists/*
WORKDIR /tmp
ENTRYPOINT []

View File

@@ -0,0 +1,10 @@
# Mock databasus API server for version checks and binary downloads. Just
# serves static responses and files from the `artifacts` directory.
FROM golang:1.26.1-alpine AS build
WORKDIR /app
COPY mock-server/main.go .
RUN CGO_ENABLED=0 go build -o mock-server main.go
FROM alpine:3.21
COPY --from=build /app/mock-server /usr/local/bin/mock-server
ENTRYPOINT ["mock-server"]

View File

@@ -0,0 +1,33 @@
services:
e2e-br-mock-server:
build:
context: .
dockerfile: Dockerfile.mock-server
volumes:
- backup-storage:/backup-storage
container_name: e2e-br-mock-server
healthcheck:
test: ["CMD", "wget", "-q", "--spider", "http://localhost:4050/health"]
interval: 2s
timeout: 5s
retries: 10
e2e-br-runner:
build:
context: .
dockerfile: Dockerfile.backup-restore-runner
args:
PG_VERSION: ${PG_VERSION:-17}
volumes:
- ./artifacts:/opt/agent/artifacts:ro
- ./scripts:/opt/agent/scripts:ro
depends_on:
e2e-br-mock-server:
condition: service_healthy
container_name: e2e-br-runner
command: ["bash", "/opt/agent/scripts/test-pg-host-path.sh"]
environment:
MOCK_SERVER_OVERRIDE: "http://e2e-br-mock-server:4050"
volumes:
backup-storage:

View File

@@ -0,0 +1,84 @@
services:
e2e-agent-builder:
build:
context: ..
dockerfile: e2e/Dockerfile.agent-builder
volumes:
- ./artifacts:/artifacts
container_name: e2e-agent-builder
e2e-postgres:
image: postgres:17
environment:
POSTGRES_DB: testdb
POSTGRES_USER: testuser
POSTGRES_PASSWORD: testpassword
container_name: e2e-agent-postgres
command:
- bash
- -c
- |
mkdir -p /wal-queue && chown postgres:postgres /wal-queue
exec docker-entrypoint.sh postgres \
-c wal_level=replica \
-c max_wal_senders=3 \
-c archive_mode=on \
-c "archive_command=cp %p /wal-queue/%f"
volumes:
- ./pgdata:/var/lib/postgresql/data
- wal-queue:/wal-queue
healthcheck:
test: ["CMD-SHELL", "pg_isready -U testuser -d testdb"]
interval: 2s
timeout: 5s
retries: 30
e2e-mock-server:
build:
context: .
dockerfile: Dockerfile.mock-server
volumes:
- ./artifacts:/artifacts:ro
- backup-storage:/backup-storage
container_name: e2e-mock-server
healthcheck:
test: ["CMD", "wget", "-q", "--spider", "http://localhost:4050/health"]
interval: 2s
timeout: 5s
retries: 10
e2e-agent-runner:
build:
context: .
dockerfile: Dockerfile.agent-runner
volumes:
- ./artifacts:/opt/agent/artifacts:ro
- ./scripts:/opt/agent/scripts:ro
depends_on:
e2e-postgres:
condition: service_healthy
e2e-mock-server:
condition: service_healthy
container_name: e2e-agent-runner
command: ["bash", "/opt/agent/scripts/run-all.sh", "host"]
e2e-agent-docker:
build:
context: .
dockerfile: Dockerfile.agent-docker
volumes:
- ./artifacts:/opt/agent/artifacts:ro
- ./scripts:/opt/agent/scripts:ro
- /var/run/docker.sock:/var/run/docker.sock
- wal-queue:/wal-queue
depends_on:
e2e-postgres:
condition: service_healthy
e2e-mock-server:
condition: service_healthy
container_name: e2e-agent-docker
command: ["bash", "/opt/agent/scripts/run-all.sh", "docker"]
volumes:
wal-queue:
backup-storage:

View File

@@ -0,0 +1,477 @@
package main
import (
"crypto/rand"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"sync"
"time"
)
const backupStorageDir = "/backup-storage"
type walSegment struct {
BackupID string
SegmentName string
FilePath string
SizeBytes int64
}
type server struct {
mu sync.RWMutex
version string
binaryPath string
backupID string
backupFilePath string
startSegment string
stopSegment string
isFinalized bool
walSegments []walSegment
backupCreatedAt time.Time
}
func main() {
version := "v2.0.0"
binaryPath := "/artifacts/agent-v2"
port := "4050"
_ = os.MkdirAll(backupStorageDir, 0o755)
s := &server{version: version, binaryPath: binaryPath}
// System endpoints
http.HandleFunc("/api/v1/system/version", s.handleVersion)
http.HandleFunc("/api/v1/system/agent", s.handleAgentDownload)
// Backup endpoints
http.HandleFunc("/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup", s.handleChainValidity)
http.HandleFunc("/api/v1/backups/postgres/wal/next-full-backup-time", s.handleNextBackupTime)
http.HandleFunc("/api/v1/backups/postgres/wal/upload/full-start", s.handleFullStart)
http.HandleFunc("/api/v1/backups/postgres/wal/upload/full-complete", s.handleFullComplete)
http.HandleFunc("/api/v1/backups/postgres/wal/upload/wal", s.handleWalUpload)
http.HandleFunc("/api/v1/backups/postgres/wal/error", s.handleError)
// Restore endpoints
http.HandleFunc("/api/v1/backups/postgres/wal/restore/plan", s.handleRestorePlan)
http.HandleFunc("/api/v1/backups/postgres/wal/restore/download", s.handleRestoreDownload)
// Mock control endpoints
http.HandleFunc("/mock/set-version", s.handleSetVersion)
http.HandleFunc("/mock/set-binary-path", s.handleSetBinaryPath)
http.HandleFunc("/mock/backup-status", s.handleBackupStatus)
http.HandleFunc("/mock/reset", s.handleReset)
http.HandleFunc("/health", s.handleHealth)
addr := ":" + port
log.Printf("Mock server starting on %s (version=%s, binary=%s)", addr, version, binaryPath)
if err := http.ListenAndServe(addr, nil); err != nil {
log.Fatalf("Server failed: %v", err)
}
}
// --- System handlers ---
func (s *server) handleVersion(w http.ResponseWriter, _ *http.Request) {
s.mu.RLock()
v := s.version
s.mu.RUnlock()
log.Printf("GET /api/v1/system/version -> %s", v)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{"version": v})
}
func (s *server) handleAgentDownload(w http.ResponseWriter, r *http.Request) {
s.mu.RLock()
path := s.binaryPath
s.mu.RUnlock()
log.Printf("GET /api/v1/system/agent (arch=%s) -> serving %s", r.URL.Query().Get("arch"), path)
http.ServeFile(w, r, path)
}
// --- Backup handlers ---
func (s *server) handleChainValidity(w http.ResponseWriter, _ *http.Request) {
s.mu.RLock()
isFinalized := s.isFinalized
s.mu.RUnlock()
log.Printf("GET chain-validity -> isFinalized=%v", isFinalized)
w.Header().Set("Content-Type", "application/json")
if isFinalized {
_ = json.NewEncoder(w).Encode(map[string]any{
"isValid": true,
})
} else {
_ = json.NewEncoder(w).Encode(map[string]any{
"isValid": false,
"error": "no full backup found",
})
}
}
func (s *server) handleNextBackupTime(w http.ResponseWriter, _ *http.Request) {
log.Printf("GET next-full-backup-time")
nextTime := time.Now().UTC().Add(1 * time.Hour)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"nextFullBackupTime": nextTime.Format(time.RFC3339),
})
}
func (s *server) handleFullStart(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "POST only", http.StatusMethodNotAllowed)
return
}
backupID := generateID()
filePath := filepath.Join(backupStorageDir, backupID+".zst")
file, err := os.Create(filePath)
if err != nil {
log.Printf("ERROR creating backup file: %v", err)
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
bytesWritten, err := io.Copy(file, r.Body)
_ = file.Close()
if err != nil {
log.Printf("ERROR writing backup data: %v", err)
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
s.mu.Lock()
s.backupID = backupID
s.backupFilePath = filePath
s.backupCreatedAt = time.Now().UTC()
s.mu.Unlock()
log.Printf("POST full-start -> backupID=%s, size=%d bytes", backupID, bytesWritten)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{"backupId": backupID})
}
func (s *server) handleFullComplete(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "POST only", http.StatusMethodNotAllowed)
return
}
var body struct {
BackupID string `json:"backupId"`
StartSegment string `json:"startSegment"`
StopSegment string `json:"stopSegment"`
Error *string `json:"error,omitempty"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if body.Error != nil {
log.Printf("POST full-complete -> backupID=%s ERROR: %s", body.BackupID, *body.Error)
w.WriteHeader(http.StatusOK)
return
}
s.mu.Lock()
s.startSegment = body.StartSegment
s.stopSegment = body.StopSegment
s.isFinalized = true
s.mu.Unlock()
log.Printf(
"POST full-complete -> backupID=%s, start=%s, stop=%s",
body.BackupID,
body.StartSegment,
body.StopSegment,
)
w.WriteHeader(http.StatusOK)
}
func (s *server) handleWalUpload(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "POST only", http.StatusMethodNotAllowed)
return
}
segmentName := r.Header.Get("X-Wal-Segment-Name")
if segmentName == "" {
http.Error(w, "missing X-Wal-Segment-Name header", http.StatusBadRequest)
return
}
walBackupID := generateID()
filePath := filepath.Join(backupStorageDir, walBackupID+".zst")
file, err := os.Create(filePath)
if err != nil {
log.Printf("ERROR creating WAL file: %v", err)
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
bytesWritten, err := io.Copy(file, r.Body)
_ = file.Close()
if err != nil {
log.Printf("ERROR writing WAL data: %v", err)
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
s.mu.Lock()
s.walSegments = append(s.walSegments, walSegment{
BackupID: walBackupID,
SegmentName: segmentName,
FilePath: filePath,
SizeBytes: bytesWritten,
})
s.mu.Unlock()
log.Printf("POST wal-upload -> segment=%s, walBackupID=%s, size=%d", segmentName, walBackupID, bytesWritten)
w.WriteHeader(http.StatusNoContent)
}
func (s *server) handleError(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "POST only", http.StatusMethodNotAllowed)
return
}
var body struct {
Error string `json:"error"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
log.Printf("POST error -> failed to decode: %v", err)
} else {
log.Printf("POST error -> %s", body.Error)
}
w.WriteHeader(http.StatusOK)
}
// --- Restore handlers ---
func (s *server) handleRestorePlan(w http.ResponseWriter, _ *http.Request) {
s.mu.RLock()
defer s.mu.RUnlock()
if !s.isFinalized {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
_ = json.NewEncoder(w).Encode(map[string]string{
"error": "no_backups",
"message": "No full backups available",
})
return
}
backupFileInfo, err := os.Stat(s.backupFilePath)
if err != nil {
log.Printf("ERROR stat backup file: %v", err)
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
backupSizeBytes := backupFileInfo.Size()
totalSizeBytes := backupSizeBytes
walSegmentsJSON := make([]map[string]any, 0, len(s.walSegments))
latestSegment := ""
for _, segment := range s.walSegments {
totalSizeBytes += segment.SizeBytes
latestSegment = segment.SegmentName
walSegmentsJSON = append(walSegmentsJSON, map[string]any{
"backupId": segment.BackupID,
"segmentName": segment.SegmentName,
"sizeBytes": segment.SizeBytes,
})
}
response := map[string]any{
"fullBackup": map[string]any{
"id": s.backupID,
"fullBackupWalStartSegment": s.startSegment,
"fullBackupWalStopSegment": s.stopSegment,
"pgVersion": "17",
"createdAt": s.backupCreatedAt.Format(time.RFC3339),
"sizeBytes": backupSizeBytes,
},
"walSegments": walSegmentsJSON,
"totalSizeBytes": totalSizeBytes,
"latestAvailableSegment": latestSegment,
}
log.Printf("GET restore-plan -> backupID=%s, walSegments=%d", s.backupID, len(s.walSegments))
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(response)
}
func (s *server) handleRestoreDownload(w http.ResponseWriter, r *http.Request) {
requestedBackupID := r.URL.Query().Get("backupId")
if requestedBackupID == "" {
http.Error(w, "missing backupId query param", http.StatusBadRequest)
return
}
filePath := s.findBackupFile(requestedBackupID)
if filePath == "" {
log.Printf("GET restore-download -> backupId=%s NOT FOUND", requestedBackupID)
http.Error(w, "backup not found", http.StatusNotFound)
return
}
log.Printf("GET restore-download -> backupId=%s, file=%s", requestedBackupID, filePath)
http.ServeFile(w, r, filePath)
}
// --- Mock control handlers ---
func (s *server) handleSetVersion(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "POST only", http.StatusMethodNotAllowed)
return
}
var body struct {
Version string `json:"version"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
s.mu.Lock()
s.version = body.Version
s.mu.Unlock()
log.Printf("POST /mock/set-version -> %s", body.Version)
_, _ = fmt.Fprintf(w, "version set to %s", body.Version)
}
func (s *server) handleSetBinaryPath(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "POST only", http.StatusMethodNotAllowed)
return
}
var body struct {
BinaryPath string `json:"binaryPath"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
s.mu.Lock()
s.binaryPath = body.BinaryPath
s.mu.Unlock()
log.Printf("POST /mock/set-binary-path -> %s", body.BinaryPath)
_, _ = fmt.Fprintf(w, "binary path set to %s", body.BinaryPath)
}
func (s *server) handleBackupStatus(w http.ResponseWriter, _ *http.Request) {
s.mu.RLock()
isFinalized := s.isFinalized
walSegmentCount := len(s.walSegments)
s.mu.RUnlock()
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"isFinalized": isFinalized,
"walSegmentCount": walSegmentCount,
})
}
func (s *server) handleReset(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "POST only", http.StatusMethodNotAllowed)
return
}
s.mu.Lock()
s.backupID = ""
s.backupFilePath = ""
s.startSegment = ""
s.stopSegment = ""
s.isFinalized = false
s.walSegments = nil
s.backupCreatedAt = time.Time{}
s.mu.Unlock()
// Clean stored files
entries, _ := os.ReadDir(backupStorageDir)
for _, entry := range entries {
_ = os.Remove(filepath.Join(backupStorageDir, entry.Name()))
}
log.Printf("POST /mock/reset -> state cleared")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}
func (s *server) handleHealth(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}
// --- Private helpers ---
func generateID() string {
b := make([]byte, 16)
_, _ = rand.Read(b)
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
}
func (s *server) findBackupFile(backupID string) string {
s.mu.RLock()
defer s.mu.RUnlock()
if s.backupID == backupID {
return s.backupFilePath
}
for _, segment := range s.walSegments {
if segment.BackupID == backupID {
return segment.FilePath
}
}
return ""
}
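The mock control endpoints make it easy to poke the server by hand while debugging a test. A short smoke-test sketch, assuming the server is reachable on localhost:4050 (inside the compose network it is addressed as e2e-mock-server instead):

curl -s http://localhost:4050/health
curl -s http://localhost:4050/api/v1/system/version
curl -s -X POST http://localhost:4050/mock/set-version \
  -H "Content-Type: application/json" -d '{"version":"v1.0.0"}'
curl -s http://localhost:4050/mock/backup-status
curl -s -X POST http://localhost:4050/mock/reset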

View File

@@ -0,0 +1,357 @@
#!/bin/bash
# Shared helper functions for backup-restore E2E tests.
# Source this file from test scripts: source "$(dirname "$0")/backup-restore-helpers.sh"
AGENT="/tmp/test-agent"
AGENT_PID=""
cleanup_agent() {
if [ -n "$AGENT_PID" ]; then
kill "$AGENT_PID" 2>/dev/null || true
wait "$AGENT_PID" 2>/dev/null || true
AGENT_PID=""
fi
pkill -f "test-agent" 2>/dev/null || true
for i in $(seq 1 20); do
pgrep -f "test-agent" > /dev/null 2>&1 || break
sleep 0.5
done
pkill -9 -f "test-agent" 2>/dev/null || true
sleep 0.5
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
}
setup_agent() {
local artifacts="${1:-/opt/agent/artifacts}"
cleanup_agent
cp "$artifacts/agent-v1" "$AGENT"
chmod +x "$AGENT"
}
init_pg_local() {
local pgdata="$1"
local port="$2"
local wal_queue="$3"
local pg_bin_dir="$4"
# Stop any leftover PG from previous test runs
su postgres -c "$pg_bin_dir/pg_ctl -D $pgdata stop -m immediate" 2>/dev/null || true
su postgres -c "$pg_bin_dir/pg_ctl -D /tmp/restore-pgdata stop -m immediate" 2>/dev/null || true
mkdir -p "$wal_queue"
chown postgres:postgres "$wal_queue"
rm -rf "$pgdata"
su postgres -c "$pg_bin_dir/initdb -D $pgdata" > /dev/null
cat >> "$pgdata/postgresql.conf" <<PGCONF
wal_level = replica
archive_mode = on
archive_command = 'cp %p $wal_queue/%f'
max_wal_senders = 3
listen_addresses = 'localhost'
port = $port
checkpoint_timeout = 30s
PGCONF
echo "local all all trust" > "$pgdata/pg_hba.conf"
echo "host all all 127.0.0.1/32 trust" >> "$pgdata/pg_hba.conf"
echo "host all all ::1/128 trust" >> "$pgdata/pg_hba.conf"
echo "local replication all trust" >> "$pgdata/pg_hba.conf"
echo "host replication all 127.0.0.1/32 trust" >> "$pgdata/pg_hba.conf"
echo "host replication all ::1/128 trust" >> "$pgdata/pg_hba.conf"
su postgres -c "$pg_bin_dir/pg_ctl -D $pgdata -l /tmp/pg.log start -w"
su postgres -c "$pg_bin_dir/psql -p $port -c \"CREATE USER testuser WITH SUPERUSER REPLICATION;\"" > /dev/null 2>&1 || true
su postgres -c "$pg_bin_dir/psql -p $port -c \"CREATE DATABASE testdb OWNER testuser;\"" > /dev/null 2>&1 || true
echo "PostgreSQL initialized and started on port $port"
}
insert_test_data() {
local port="$1"
local pg_bin_dir="$2"
su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb" <<SQL
CREATE TABLE e2e_test_data (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL,
value INT NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW()
);
INSERT INTO e2e_test_data (name, value) VALUES
('row1', 100),
('row2', 200),
('row3', 300);
SQL
echo "Test data inserted (3 rows)"
}
force_checkpoint() {
local port="$1"
local pg_bin_dir="$2"
su postgres -c "$pg_bin_dir/psql -p $port -c 'CHECKPOINT;'" > /dev/null
echo "Checkpoint forced"
}
run_agent_backup() {
local mock_server="$1"
local pg_host="$2"
local pg_port="$3"
local wal_queue="$4"
local pg_type="$5"
local pg_host_bin_dir="${6:-}"
local pg_docker_container="${7:-}"
# Reset mock server state and set version to match agent (prevents background upgrade loop)
curl -sf -X POST "$mock_server/mock/reset" > /dev/null
curl -sf -X POST "$mock_server/mock/set-version" \
-H "Content-Type: application/json" \
-d '{"version":"v1.0.0"}' > /dev/null
# Build JSON config
cd /tmp
local extra_fields=""
if [ -n "$pg_host_bin_dir" ]; then
extra_fields="$extra_fields\"pgHostBinDir\": \"$pg_host_bin_dir\","
fi
if [ -n "$pg_docker_container" ]; then
extra_fields="$extra_fields\"pgDockerContainerName\": \"$pg_docker_container\","
fi
cat > databasus.json <<AGENTCONF
{
"databasusHost": "$mock_server",
"dbId": "test-db-id",
"token": "test-token",
"pgHost": "$pg_host",
"pgPort": $pg_port,
"pgUser": "testuser",
"pgPassword": "",
${extra_fields}
"pgType": "$pg_type",
"pgWalDir": "$wal_queue",
"deleteWalAfterUpload": true
}
AGENTCONF
# Run agent daemon in background
"$AGENT" _run > /tmp/agent-output.log 2>&1 &
AGENT_PID=$!
echo "Agent started with PID $AGENT_PID"
}
generate_wal_background() {
local port="$1"
local pg_bin_dir="$2"
while true; do
su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb -c \"
INSERT INTO e2e_test_data (name, value)
SELECT 'bulk_' || g, g FROM generate_series(1, 1000) g;
SELECT pg_switch_wal();
\"" > /dev/null 2>&1 || break
sleep 2
done
}
generate_wal_docker_background() {
local container="$1"
while true; do
docker exec "$container" psql -U testuser -d testdb -c "
INSERT INTO e2e_test_data (name, value)
SELECT 'bulk_' || g, g FROM generate_series(1, 1000) g;
SELECT pg_switch_wal();
" > /dev/null 2>&1 || break
sleep 2
done
}
wait_for_backup_complete() {
local mock_server="$1"
local timeout="${2:-120}"
echo "Waiting for backup to complete (timeout: ${timeout}s)..."
for i in $(seq 1 "$timeout"); do
STATUS=$(curl -sf "$mock_server/mock/backup-status" 2>/dev/null || echo '{}')
IS_FINALIZED=$(echo "$STATUS" | grep -o '"isFinalized":true' || true)
WAL_COUNT=$(echo "$STATUS" | grep -o '"walSegmentCount":[0-9]*' | grep -o '[0-9]*$' || echo "0")
if [ -n "$IS_FINALIZED" ] && [ "$WAL_COUNT" -gt 0 ]; then
echo "Backup complete: finalized with $WAL_COUNT WAL segments"
return 0
fi
sleep 1
done
echo "FAIL: Backup did not complete within ${timeout} seconds"
echo "Last status: $STATUS"
echo "Agent output:"
cat /tmp/agent-output.log 2>/dev/null || true
return 1
}
stop_agent() {
if [ -n "$AGENT_PID" ]; then
kill "$AGENT_PID" 2>/dev/null || true
wait "$AGENT_PID" 2>/dev/null || true
AGENT_PID=""
fi
echo "Agent stopped"
}
stop_pg() {
local pgdata="$1"
local pg_bin_dir="$2"
su postgres -c "$pg_bin_dir/pg_ctl -D $pgdata stop -m fast" 2>/dev/null || true
echo "PostgreSQL stopped"
}
run_agent_restore() {
local mock_server="$1"
local restore_dir="$2"
rm -rf "$restore_dir"
mkdir -p "$restore_dir"
chown postgres:postgres "$restore_dir"
cd /tmp
"$AGENT" restore \
--skip-update \
--databasus-host "$mock_server" \
--token test-token \
--target-dir "$restore_dir"
echo "Agent restore completed"
}
start_restored_pg() {
local restore_dir="$1"
local port="$2"
local pg_bin_dir="$3"
# Ensure port is set in restored config
if ! grep -q "^port" "$restore_dir/postgresql.conf" 2>/dev/null; then
echo "port = $port" >> "$restore_dir/postgresql.conf"
fi
# Ensure listen_addresses is set
if ! grep -q "^listen_addresses" "$restore_dir/postgresql.conf" 2>/dev/null; then
echo "listen_addresses = 'localhost'" >> "$restore_dir/postgresql.conf"
fi
chown -R postgres:postgres "$restore_dir"
chmod 700 "$restore_dir"
if ! su postgres -c "$pg_bin_dir/pg_ctl -D $restore_dir -l /tmp/pg-restore.log start -w"; then
echo "FAIL: PostgreSQL failed to start on restored data"
echo "--- pg-restore.log ---"
cat /tmp/pg-restore.log 2>/dev/null || echo "(no log file)"
echo "--- postgresql.auto.conf ---"
cat "$restore_dir/postgresql.auto.conf" 2>/dev/null || echo "(no file)"
echo "--- pg_wal/ listing ---"
ls -la "$restore_dir/pg_wal/" 2>/dev/null || echo "(no pg_wal dir)"
echo "--- databasus-wal-restore/ listing ---"
ls -la "$restore_dir/databasus-wal-restore/" 2>/dev/null || echo "(no dir)"
echo "--- end diagnostics ---"
return 1
fi
echo "PostgreSQL started on restored data"
}
wait_for_recovery_complete() {
local port="$1"
local pg_bin_dir="$2"
local timeout="${3:-60}"
echo "Waiting for recovery to complete (timeout: ${timeout}s)..."
for i in $(seq 1 "$timeout"); do
IS_READY=$(su postgres -c "$pg_bin_dir/pg_isready -p $port" 2>&1 || true)
if echo "$IS_READY" | grep -q "accepting connections"; then
IN_RECOVERY=$(su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb -t -c 'SELECT pg_is_in_recovery();'" 2>/dev/null | tr -d ' \n' || echo "t")
if [ "$IN_RECOVERY" = "f" ]; then
echo "PostgreSQL recovered and promoted to primary"
return 0
fi
fi
sleep 1
done
echo "FAIL: PostgreSQL did not recover within ${timeout} seconds"
echo "Recovery log:"
cat /tmp/pg-restore.log 2>/dev/null || true
return 1
}
verify_restored_data() {
local port="$1"
local pg_bin_dir="$2"
ROW_COUNT=$(su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb -t -c 'SELECT COUNT(*) FROM e2e_test_data;'" | tr -d ' \n')
if [ "$ROW_COUNT" -lt 3 ]; then
echo "FAIL: Expected at least 3 rows, got $ROW_COUNT"
su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb -c 'SELECT * FROM e2e_test_data;'"
return 1
fi
RESULT=$(su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb -t -c \"SELECT value FROM e2e_test_data WHERE name='row1';\"" | tr -d ' \n')
if [ "$RESULT" != "100" ]; then
echo "FAIL: Expected row1 value=100, got $RESULT"
return 1
fi
RESULT2=$(su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb -t -c \"SELECT value FROM e2e_test_data WHERE name='row3';\"" | tr -d ' \n')
if [ "$RESULT2" != "300" ]; then
echo "FAIL: Expected row3 value=300, got $RESULT2"
return 1
fi
echo "PASS: Found $ROW_COUNT rows, data integrity verified"
return 0
}
find_pg_bin_dir() {
# Find the bin dir of the installed PostgreSQL version
local pg_config_path
pg_config_path=$(which pg_config 2>/dev/null || true)
if [ -n "$pg_config_path" ]; then
pg_config --bindir
return
fi
# Fallback: search common locations
for version in 18 17 16 15; do
if [ -d "/usr/lib/postgresql/$version/bin" ]; then
echo "/usr/lib/postgresql/$version/bin"
return
fi
done
echo "ERROR: Cannot find PostgreSQL bin directory" >&2
return 1
}
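
The databasus.json that run_agent_backup writes above is the same file the Go agent loads on startup; its keys map one-to-one onto the Config struct shown later in this diff. A minimal standalone sketch of that mapping (not the agent's own loader; the field set is trimmed and the payload values are placeholders):

package main

import (
	"encoding/json"
	"fmt"
)

// Mirrors a subset of the JSON keys run_agent_backup writes into databasus.json.
type agentConfig struct {
	DatabasusHost        string `json:"databasusHost"`
	DbID                 string `json:"dbId"`
	Token                string `json:"token"`
	PgHost               string `json:"pgHost"`
	PgPort               int    `json:"pgPort"`
	PgUser               string `json:"pgUser"`
	PgType               string `json:"pgType"`
	PgWalDir             string `json:"pgWalDir"`
	DeleteWalAfterUpload *bool  `json:"deleteWalAfterUpload"`
}

func main() {
	// Placeholder payload in the shape produced by run_agent_backup.
	raw := []byte(`{"databasusHost":"http://mock:4050","dbId":"test-db-id","token":"test-token",
"pgHost":"127.0.0.1","pgPort":5433,"pgUser":"testuser","pgType":"host",
"pgWalDir":"/tmp/wal-queue","deleteWalAfterUpload":true}`)

	var cfg agentConfig
	if err := json.Unmarshal(raw, &cfg); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", cfg)
}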

View File

@@ -0,0 +1,56 @@
#!/bin/bash
set -euo pipefail
MODE="${1:-host}"
SCRIPT_DIR="$(dirname "$0")"
PASSED=0
FAILED=0
FAILED_NAMES=""
run_test() {
local name="$1"
local script="$2"
echo ""
echo "========================================"
echo " $name"
echo "========================================"
if bash "$script"; then
echo " PASSED: $name"
PASSED=$((PASSED + 1))
else
echo " FAILED: $name"
FAILED=$((FAILED + 1))
FAILED_NAMES="${FAILED_NAMES}\n - ${name}"
fi
}
if [ "$MODE" = "host" ]; then
run_test "Test 1: Upgrade success (v1 -> v2)" "$SCRIPT_DIR/test-upgrade-success.sh"
run_test "Test 2: Upgrade skip (version matches)" "$SCRIPT_DIR/test-upgrade-skip.sh"
run_test "Test 3: Background upgrade (v1 -> v2 while running)" "$SCRIPT_DIR/test-upgrade-background.sh"
run_test "Test 4: Backup-restore via host PATH" "$SCRIPT_DIR/test-pg-host-path.sh"
run_test "Test 5: Backup-restore via host bindir" "$SCRIPT_DIR/test-pg-host-bindir.sh"
elif [ "$MODE" = "docker" ]; then
run_test "Test 6: Backup-restore via docker exec" "$SCRIPT_DIR/test-pg-docker-exec.sh"
else
echo "Unknown mode: $MODE (expected 'host' or 'docker')"
exit 1
fi
echo ""
echo "========================================"
echo " Results: $PASSED passed, $FAILED failed"
if [ "$FAILED" -gt 0 ]; then
echo ""
echo " Failed:"
echo -e "$FAILED_NAMES"
fi
echo "========================================"
if [ "$FAILED" -gt 0 ]; then
exit 1
fi

View File

@@ -0,0 +1,95 @@
#!/bin/bash
set -euo pipefail
SCRIPT_DIR="$(dirname "$0")"
source "$SCRIPT_DIR/backup-restore-helpers.sh"
MOCK_SERVER="${MOCK_SERVER_OVERRIDE:-http://e2e-mock-server:4050}"
PG_CONTAINER="e2e-agent-postgres"
RESTORE_PGDATA="/tmp/restore-pgdata"
WAL_QUEUE="/wal-queue"
PG_PORT=5432
# For restore verification we need a local PG bin dir
PG_BIN_DIR=$(find_pg_bin_dir)
echo "Using local PG bin dir for restore verification: $PG_BIN_DIR"
# Verify docker CLI works and PG container is accessible
if ! docker exec "$PG_CONTAINER" pg_basebackup --version > /dev/null 2>&1; then
echo "FAIL: Cannot reach pg_basebackup inside container $PG_CONTAINER (test setup issue)"
exit 1
fi
echo "=== Phase 1: Setup agent ==="
setup_agent
echo "=== Phase 2: Insert test data into containerized PostgreSQL ==="
docker exec "$PG_CONTAINER" psql -U testuser -d testdb -c "
CREATE TABLE IF NOT EXISTS e2e_test_data (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL,
value INT NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW()
);
DELETE FROM e2e_test_data;
INSERT INTO e2e_test_data (name, value) VALUES
('row1', 100),
('row2', 200),
('row3', 300);
"
echo "Test data inserted (3 rows)"
echo "=== Phase 3: Start agent backup (docker exec mode) ==="
curl -sf -X POST "$MOCK_SERVER/mock/reset" > /dev/null
cd /tmp
cat > databasus.json <<AGENTCONF
{
"databasusHost": "$MOCK_SERVER",
"dbId": "test-db-id",
"token": "test-token",
"pgHost": "$PG_CONTAINER",
"pgPort": $PG_PORT,
"pgUser": "testuser",
"pgPassword": "testpassword",
"pgType": "docker",
"pgDockerContainerName": "$PG_CONTAINER",
"pgWalDir": "$WAL_QUEUE",
"deleteWalAfterUpload": true
}
AGENTCONF
"$AGENT" _run > /tmp/agent-output.log 2>&1 &
AGENT_PID=$!
echo "Agent started with PID $AGENT_PID"
echo "=== Phase 4: Generate WAL in background ==="
generate_wal_docker_background "$PG_CONTAINER" &
WAL_GEN_PID=$!
echo "=== Phase 5: Wait for backup to complete ==="
wait_for_backup_complete "$MOCK_SERVER" 120
echo "=== Phase 6: Stop WAL generator and agent ==="
kill $WAL_GEN_PID 2>/dev/null || true
wait $WAL_GEN_PID 2>/dev/null || true
stop_agent
echo "=== Phase 7: Restore to local directory ==="
run_agent_restore "$MOCK_SERVER" "$RESTORE_PGDATA"
echo "=== Phase 8: Start local PostgreSQL on restored data ==="
# Use a different port to avoid conflict with the containerized PG
RESTORE_PORT=5433
start_restored_pg "$RESTORE_PGDATA" "$RESTORE_PORT" "$PG_BIN_DIR"
echo "=== Phase 9: Wait for recovery ==="
wait_for_recovery_complete "$RESTORE_PORT" "$PG_BIN_DIR" 60
echo "=== Phase 10: Verify data ==="
verify_restored_data "$RESTORE_PORT" "$PG_BIN_DIR"
echo "=== Phase 11: Cleanup ==="
stop_pg "$RESTORE_PGDATA" "$PG_BIN_DIR"
echo "pg_basebackup via docker exec: full backup-restore lifecycle passed"

View File

@@ -0,0 +1,62 @@
#!/bin/bash
set -euo pipefail
SCRIPT_DIR="$(dirname "$0")"
source "$SCRIPT_DIR/backup-restore-helpers.sh"
MOCK_SERVER="${MOCK_SERVER_OVERRIDE:-http://e2e-mock-server:4050}"
PGDATA="/tmp/pgdata"
RESTORE_PGDATA="/tmp/restore-pgdata"
WAL_QUEUE="/tmp/wal-queue"
PG_PORT=5433
CUSTOM_BIN_DIR="/opt/pg/bin"
PG_BIN_DIR=$(find_pg_bin_dir)
echo "Using PG bin dir: $PG_BIN_DIR"
# Copy pg_basebackup to a custom directory (simulates non-PATH installation)
mkdir -p "$CUSTOM_BIN_DIR"
cp "$PG_BIN_DIR/pg_basebackup" "$CUSTOM_BIN_DIR/pg_basebackup"
echo "=== Phase 1: Setup agent ==="
setup_agent
echo "=== Phase 2: Initialize PostgreSQL ==="
init_pg_local "$PGDATA" "$PG_PORT" "$WAL_QUEUE" "$PG_BIN_DIR"
echo "=== Phase 3: Insert test data ==="
insert_test_data "$PG_PORT" "$PG_BIN_DIR"
echo "=== Phase 4: Force checkpoint and start agent backup (using --pg-host-bin-dir) ==="
force_checkpoint "$PG_PORT" "$PG_BIN_DIR"
run_agent_backup "$MOCK_SERVER" "127.0.0.1" "$PG_PORT" "$WAL_QUEUE" "host" "$CUSTOM_BIN_DIR"
echo "=== Phase 5: Generate WAL in background ==="
generate_wal_background "$PG_PORT" "$PG_BIN_DIR" &
WAL_GEN_PID=$!
echo "=== Phase 6: Wait for backup to complete ==="
wait_for_backup_complete "$MOCK_SERVER" 120
echo "=== Phase 7: Stop WAL generator, agent, and PostgreSQL ==="
kill $WAL_GEN_PID 2>/dev/null || true
wait $WAL_GEN_PID 2>/dev/null || true
stop_agent
stop_pg "$PGDATA" "$PG_BIN_DIR"
echo "=== Phase 8: Restore ==="
run_agent_restore "$MOCK_SERVER" "$RESTORE_PGDATA"
echo "=== Phase 9: Start PostgreSQL on restored data ==="
start_restored_pg "$RESTORE_PGDATA" "$PG_PORT" "$PG_BIN_DIR"
echo "=== Phase 10: Wait for recovery ==="
wait_for_recovery_complete "$PG_PORT" "$PG_BIN_DIR" 60
echo "=== Phase 11: Verify data ==="
verify_restored_data "$PG_PORT" "$PG_BIN_DIR"
echo "=== Phase 12: Cleanup ==="
stop_pg "$RESTORE_PGDATA" "$PG_BIN_DIR"
echo "pg_basebackup via custom bindir: full backup-restore lifecycle passed"

View File

@@ -0,0 +1,63 @@
#!/bin/bash
set -euo pipefail
SCRIPT_DIR="$(dirname "$0")"
source "$SCRIPT_DIR/backup-restore-helpers.sh"
MOCK_SERVER="${MOCK_SERVER_OVERRIDE:-http://e2e-mock-server:4050}"
PGDATA="/tmp/pgdata"
RESTORE_PGDATA="/tmp/restore-pgdata"
WAL_QUEUE="/tmp/wal-queue"
PG_PORT=5433
PG_BIN_DIR=$(find_pg_bin_dir)
echo "Using PG bin dir: $PG_BIN_DIR"
# Verify pg_basebackup is in PATH
if ! which pg_basebackup > /dev/null 2>&1; then
echo "FAIL: pg_basebackup not found in PATH (test setup issue)"
exit 1
fi
echo "=== Phase 1: Setup agent ==="
setup_agent
echo "=== Phase 2: Initialize PostgreSQL ==="
init_pg_local "$PGDATA" "$PG_PORT" "$WAL_QUEUE" "$PG_BIN_DIR"
echo "=== Phase 3: Insert test data ==="
insert_test_data "$PG_PORT" "$PG_BIN_DIR"
echo "=== Phase 4: Force checkpoint and start agent backup ==="
force_checkpoint "$PG_PORT" "$PG_BIN_DIR"
run_agent_backup "$MOCK_SERVER" "127.0.0.1" "$PG_PORT" "$WAL_QUEUE" "host"
echo "=== Phase 5: Generate WAL in background ==="
generate_wal_background "$PG_PORT" "$PG_BIN_DIR" &
WAL_GEN_PID=$!
echo "=== Phase 6: Wait for backup to complete ==="
wait_for_backup_complete "$MOCK_SERVER" 120
echo "=== Phase 7: Stop WAL generator, agent, and PostgreSQL ==="
kill $WAL_GEN_PID 2>/dev/null || true
wait $WAL_GEN_PID 2>/dev/null || true
stop_agent
stop_pg "$PGDATA" "$PG_BIN_DIR"
echo "=== Phase 8: Restore ==="
run_agent_restore "$MOCK_SERVER" "$RESTORE_PGDATA"
echo "=== Phase 9: Start PostgreSQL on restored data ==="
start_restored_pg "$RESTORE_PGDATA" "$PG_PORT" "$PG_BIN_DIR"
echo "=== Phase 10: Wait for recovery ==="
wait_for_recovery_complete "$PG_PORT" "$PG_BIN_DIR" 60
echo "=== Phase 11: Verify data ==="
verify_restored_data "$PG_PORT" "$PG_BIN_DIR"
echo "=== Phase 12: Cleanup ==="
stop_pg "$RESTORE_PGDATA" "$PG_BIN_DIR"
echo "pg_basebackup in PATH: full backup-restore lifecycle passed"

View File

@@ -0,0 +1,90 @@
#!/bin/bash
set -euo pipefail
ARTIFACTS="/opt/agent/artifacts"
AGENT="/tmp/test-agent"
# Cleanup from previous runs
pkill -f "test-agent" 2>/dev/null || true
for i in $(seq 1 20); do
pgrep -f "test-agent" > /dev/null 2>&1 || break
sleep 0.5
done
pkill -9 -f "test-agent" 2>/dev/null || true
sleep 0.5
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
# Set mock server to v1.0.0 (same as agent — no sync upgrade on start)
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
-H "Content-Type: application/json" \
-d '{"version":"v1.0.0"}'
curl -sf -X POST http://e2e-mock-server:4050/mock/set-binary-path \
-H "Content-Type: application/json" \
-d '{"binaryPath":"/artifacts/agent-v1"}'
# Copy v1 binary to writable location
cp "$ARTIFACTS/agent-v1" "$AGENT"
chmod +x "$AGENT"
# Verify initial version
VERSION=$("$AGENT" version)
if [ "$VERSION" != "v1.0.0" ]; then
echo "FAIL: Expected initial version v1.0.0, got $VERSION"
exit 1
fi
echo "Initial version: $VERSION"
# Start agent as daemon (versions match → no sync upgrade)
mkdir -p /tmp/wal
"$AGENT" start \
--databasus-host http://e2e-mock-server:4050 \
--db-id test-db-id \
--token test-token \
--pg-host e2e-postgres \
--pg-port 5432 \
--pg-user testuser \
--pg-password testpassword \
--pg-wal-dir /tmp/wal \
--pg-type host
echo "Agent started as daemon, waiting for stabilization..."
sleep 2
# Change mock server to v2.0.0 and point to v2 binary
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
-H "Content-Type: application/json" \
-d '{"version":"v2.0.0"}'
curl -sf -X POST http://e2e-mock-server:4050/mock/set-binary-path \
-H "Content-Type: application/json" \
-d '{"binaryPath":"/artifacts/agent-v2"}'
echo "Mock server updated to v2.0.0, waiting for background upgrade..."
# Poll for upgrade (timeout 60s, poll every 3s)
DEADLINE=$((SECONDS + 60))
while [ $SECONDS -lt $DEADLINE ]; do
VERSION=$("$AGENT" version)
if [ "$VERSION" = "v2.0.0" ]; then
echo "Binary upgraded to $VERSION"
break
fi
sleep 3
done
VERSION=$("$AGENT" version)
if [ "$VERSION" != "v2.0.0" ]; then
echo "FAIL: Expected v2.0.0 after background upgrade, got $VERSION"
cat databasus.log 2>/dev/null || true
exit 1
fi
# Verify agent is still running after restart
sleep 2
"$AGENT" status || true
# Cleanup
"$AGENT" stop || true
echo "Background upgrade test passed"

View File

@@ -0,0 +1,64 @@
#!/bin/bash
set -euo pipefail
ARTIFACTS="/opt/agent/artifacts"
AGENT="/tmp/test-agent"
# Cleanup from previous runs
pkill -f "test-agent" 2>/dev/null || true
for i in $(seq 1 20); do
pgrep -f "test-agent" > /dev/null 2>&1 || break
sleep 0.5
done
pkill -9 -f "test-agent" 2>/dev/null || true
sleep 0.5
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
# Set mock server to return v1.0.0 (same as agent)
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
-H "Content-Type: application/json" \
-d '{"version":"v1.0.0"}'
# Copy v1 binary to writable location
cp "$ARTIFACTS/agent-v1" "$AGENT"
chmod +x "$AGENT"
# Verify initial version
VERSION=$("$AGENT" version)
if [ "$VERSION" != "v1.0.0" ]; then
echo "FAIL: Expected initial version v1.0.0, got $VERSION"
exit 1
fi
# Run start — the agent should see that the versions match and skip the upgrade
echo "Running agent start (expecting upgrade skip)..."
OUTPUT=$("$AGENT" start \
--databasus-host http://e2e-mock-server:4050 \
--db-id test-db-id \
--token test-token \
--pg-host e2e-postgres \
--pg-port 5432 \
--pg-user testuser \
--pg-password testpassword \
--pg-wal-dir /tmp/wal \
--pg-type host 2>&1) || true
echo "$OUTPUT"
# Verify output contains "up to date"
if ! echo "$OUTPUT" | grep -qi "up to date"; then
echo "FAIL: Expected output to contain 'up to date'"
exit 1
fi
# Verify binary is still v1
VERSION=$("$AGENT" version)
if [ "$VERSION" != "v1.0.0" ]; then
echo "FAIL: Expected version v1.0.0 (unchanged), got $VERSION"
exit 1
fi
echo "Upgrade correctly skipped, version still $VERSION"
# Cleanup daemon
"$AGENT" stop || true

View File

@@ -0,0 +1,69 @@
#!/bin/bash
set -euo pipefail
ARTIFACTS="/opt/agent/artifacts"
AGENT="/tmp/test-agent"
# Cleanup from previous runs
pkill -f "test-agent" 2>/dev/null || true
for i in $(seq 1 20); do
pgrep -f "test-agent" > /dev/null 2>&1 || break
sleep 0.5
done
pkill -9 -f "test-agent" 2>/dev/null || true
sleep 0.5
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
# Ensure mock server returns v2.0.0 and serves v2 binary
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
-H "Content-Type: application/json" \
-d '{"version":"v2.0.0"}'
curl -sf -X POST http://e2e-mock-server:4050/mock/set-binary-path \
-H "Content-Type: application/json" \
-d '{"binaryPath":"/artifacts/agent-v2"}'
# Copy v1 binary to writable location
cp "$ARTIFACTS/agent-v1" "$AGENT"
chmod +x "$AGENT"
# Verify initial version
VERSION=$("$AGENT" version)
if [ "$VERSION" != "v1.0.0" ]; then
echo "FAIL: Expected initial version v1.0.0, got $VERSION"
exit 1
fi
echo "Initial version: $VERSION"
# Run start — agent will:
# 1. Fetch version from mock (v2.0.0 != v1.0.0)
# 2. Download v2 binary from mock
# 3. Replace itself on disk
# 4. Re-exec with same args
# 5. Re-exec'd v2 fetches version (v2.0.0 == v2.0.0) → skips update
# 6. Proceeds to start → verifies pg_basebackup + DB → exits 0 (stub)
echo "Running agent start (expecting upgrade v1 -> v2)..."
OUTPUT=$("$AGENT" start \
--databasus-host http://e2e-mock-server:4050 \
--db-id test-db-id \
--token test-token \
--pg-host e2e-postgres \
--pg-port 5432 \
--pg-user testuser \
--pg-password testpassword \
--pg-wal-dir /tmp/wal \
--pg-type host 2>&1) || true
echo "$OUTPUT"
# Verify binary on disk is now v2
VERSION=$("$AGENT" version)
if [ "$VERSION" != "v2.0.0" ]; then
echo "FAIL: Expected upgraded version v2.0.0, got $VERSION"
exit 1
fi
echo "Binary upgraded successfully to $VERSION"
# Cleanup daemon
"$AGENT" stop || true

View File

@@ -2,10 +2,21 @@ module databasus-agent
go 1.26.1
require github.com/stretchr/testify v1.11.1
require (
github.com/go-resty/resty/v2 v2.17.2
github.com/jackc/pgx/v5 v5.8.0
github.com/klauspost/compress v1.18.4
github.com/stretchr/testify v1.11.1
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
golang.org/x/net v0.43.0 // indirect
golang.org/x/text v0.29.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -1,10 +1,43 @@
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk=
github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -3,6 +3,7 @@ package config
import (
"encoding/json"
"flag"
"fmt"
"os"
"databasus-agent/internal/logger"
@@ -13,9 +14,18 @@ var log = logger.GetLogger()
const configFileName = "databasus.json"
type Config struct {
DatabasusHost string `json:"databasusHost"`
DbID string `json:"dbId"`
Token string `json:"token"`
DatabasusHost string `json:"databasusHost"`
DbID string `json:"dbId"`
Token string `json:"token"`
PgHost string `json:"pgHost"`
PgPort int `json:"pgPort"`
PgUser string `json:"pgUser"`
PgPassword string `json:"pgPassword"`
PgType string `json:"pgType"`
PgHostBinDir string `json:"pgHostBinDir"`
PgDockerContainerName string `json:"pgDockerContainerName"`
PgWalDir string `json:"pgWalDir"`
IsDeleteWalAfterUpload *bool `json:"deleteWalAfterUpload"`
flags parsedFlags
}
@@ -24,15 +34,24 @@ type Config struct {
// and overrides JSON values with any explicitly provided CLI flags.
func (c *Config) LoadFromJSONAndArgs(fs *flag.FlagSet, args []string) {
c.loadFromJSON()
c.applyDefaults()
c.initSources()
c.flags.host = fs.String(
c.flags.databasusHost = fs.String(
"databasus-host",
"",
"Databasus server URL (e.g. http://your-server:4005)",
)
c.flags.dbID = fs.String("db-id", "", "Database ID")
c.flags.token = fs.String("token", "", "Agent token")
c.flags.pgHost = fs.String("pg-host", "", "PostgreSQL host")
c.flags.pgPort = fs.Int("pg-port", 0, "PostgreSQL port")
c.flags.pgUser = fs.String("pg-user", "", "PostgreSQL user")
c.flags.pgPassword = fs.String("pg-password", "", "PostgreSQL password")
c.flags.pgType = fs.String("pg-type", "", "PostgreSQL type: host or docker")
c.flags.pgHostBinDir = fs.String("pg-host-bin-dir", "", "Path to PG bin directory (host mode)")
c.flags.pgDockerContainerName = fs.String("pg-docker-container-name", "", "Docker container name (docker mode)")
c.flags.pgWalDir = fs.String("pg-wal-dir", "", "Path to WAL queue directory")
if err := fs.Parse(args); err != nil {
os.Exit(1)
@@ -54,6 +73,11 @@ func (c *Config) SaveToJSON() error {
return os.WriteFile(configFileName, data, 0o644)
}
func (c *Config) LoadFromJSON() {
c.loadFromJSON()
c.applyDefaults()
}
func (c *Config) loadFromJSON() {
data, err := os.ReadFile(configFileName)
if err != nil {
@@ -76,11 +100,34 @@ func (c *Config) loadFromJSON() {
log.Info("Configuration loaded from " + configFileName)
}
func (c *Config) applyDefaults() {
if c.PgPort == 0 {
c.PgPort = 5432
}
if c.PgType == "" {
c.PgType = "host"
}
if c.IsDeleteWalAfterUpload == nil {
c.IsDeleteWalAfterUpload = new(true)
}
}
func (c *Config) initSources() {
c.flags.sources = map[string]string{
"databasus-host": "not configured",
"db-id": "not configured",
"token": "not configured",
"databasus-host": "not configured",
"db-id": "not configured",
"token": "not configured",
"pg-host": "not configured",
"pg-port": "not configured",
"pg-user": "not configured",
"pg-password": "not configured",
"pg-type": "not configured",
"pg-host-bin-dir": "not configured",
"pg-docker-container-name": "not configured",
"pg-wal-dir": "not configured",
"delete-wal-after-upload": "not configured",
}
if c.DatabasusHost != "" {
@@ -94,11 +141,44 @@ func (c *Config) initSources() {
if c.Token != "" {
c.flags.sources["token"] = configFileName
}
if c.PgHost != "" {
c.flags.sources["pg-host"] = configFileName
}
// PgPort always has a value after applyDefaults
c.flags.sources["pg-port"] = configFileName
if c.PgUser != "" {
c.flags.sources["pg-user"] = configFileName
}
if c.PgPassword != "" {
c.flags.sources["pg-password"] = configFileName
}
// PgType always has a value after applyDefaults
c.flags.sources["pg-type"] = configFileName
if c.PgHostBinDir != "" {
c.flags.sources["pg-host-bin-dir"] = configFileName
}
if c.PgDockerContainerName != "" {
c.flags.sources["pg-docker-container-name"] = configFileName
}
if c.PgWalDir != "" {
c.flags.sources["pg-wal-dir"] = configFileName
}
// IsDeleteWalAfterUpload always has a value after applyDefaults
c.flags.sources["delete-wal-after-upload"] = configFileName
}
func (c *Config) applyFlags() {
if c.flags.host != nil && *c.flags.host != "" {
c.DatabasusHost = *c.flags.host
if c.flags.databasusHost != nil && *c.flags.databasusHost != "" {
c.DatabasusHost = *c.flags.databasusHost
c.flags.sources["databasus-host"] = "command line args"
}
@@ -111,18 +191,73 @@ func (c *Config) applyFlags() {
c.Token = *c.flags.token
c.flags.sources["token"] = "command line args"
}
if c.flags.pgHost != nil && *c.flags.pgHost != "" {
c.PgHost = *c.flags.pgHost
c.flags.sources["pg-host"] = "command line args"
}
if c.flags.pgPort != nil && *c.flags.pgPort != 0 {
c.PgPort = *c.flags.pgPort
c.flags.sources["pg-port"] = "command line args"
}
if c.flags.pgUser != nil && *c.flags.pgUser != "" {
c.PgUser = *c.flags.pgUser
c.flags.sources["pg-user"] = "command line args"
}
if c.flags.pgPassword != nil && *c.flags.pgPassword != "" {
c.PgPassword = *c.flags.pgPassword
c.flags.sources["pg-password"] = "command line args"
}
if c.flags.pgType != nil && *c.flags.pgType != "" {
c.PgType = *c.flags.pgType
c.flags.sources["pg-type"] = "command line args"
}
if c.flags.pgHostBinDir != nil && *c.flags.pgHostBinDir != "" {
c.PgHostBinDir = *c.flags.pgHostBinDir
c.flags.sources["pg-host-bin-dir"] = "command line args"
}
if c.flags.pgDockerContainerName != nil && *c.flags.pgDockerContainerName != "" {
c.PgDockerContainerName = *c.flags.pgDockerContainerName
c.flags.sources["pg-docker-container-name"] = "command line args"
}
if c.flags.pgWalDir != nil && *c.flags.pgWalDir != "" {
c.PgWalDir = *c.flags.pgWalDir
c.flags.sources["pg-wal-dir"] = "command line args"
}
}
func (c *Config) logConfigSources() {
log.Info(
"databasus-host",
"value",
c.DatabasusHost,
"source",
c.flags.sources["databasus-host"],
)
log.Info("databasus-host", "value", c.DatabasusHost, "source", c.flags.sources["databasus-host"])
log.Info("db-id", "value", c.DbID, "source", c.flags.sources["db-id"])
log.Info("token", "value", maskSensitive(c.Token), "source", c.flags.sources["token"])
log.Info("pg-host", "value", c.PgHost, "source", c.flags.sources["pg-host"])
log.Info("pg-port", "value", c.PgPort, "source", c.flags.sources["pg-port"])
log.Info("pg-user", "value", c.PgUser, "source", c.flags.sources["pg-user"])
log.Info("pg-password", "value", maskSensitive(c.PgPassword), "source", c.flags.sources["pg-password"])
log.Info("pg-type", "value", c.PgType, "source", c.flags.sources["pg-type"])
log.Info("pg-host-bin-dir", "value", c.PgHostBinDir, "source", c.flags.sources["pg-host-bin-dir"])
log.Info(
"pg-docker-container-name",
"value",
c.PgDockerContainerName,
"source",
c.flags.sources["pg-docker-container-name"],
)
log.Info("pg-wal-dir", "value", c.PgWalDir, "source", c.flags.sources["pg-wal-dir"])
log.Info(
"delete-wal-after-upload",
"value",
fmt.Sprintf("%v", *c.IsDeleteWalAfterUpload),
"source",
c.flags.sources["delete-wal-after-upload"],
)
}
func maskSensitive(value string) string {

View File

@@ -86,10 +86,12 @@ func Test_LoadFromJSONAndArgs_PartialArgsOverrideJSON(t *testing.T) {
func Test_SaveToJSON_ConfigSavedCorrectly(t *testing.T) {
setupTempDir(t)
deleteWal := true
cfg := &Config{
DatabasusHost: "http://save-host:4005",
DbID: "save-db-id",
Token: "save-token",
DatabasusHost: "http://save-host:4005",
DbID: "save-db-id",
Token: "save-token",
IsDeleteWalAfterUpload: &deleteWal,
}
err := cfg.SaveToJSON()
@@ -126,6 +128,143 @@ func Test_SaveToJSON_AfterArgsOverrideJSON_SavedFileContainsMergedValues(t *test
assert.Equal(t, "json-token", saved.Token)
}
func Test_LoadFromJSONAndArgs_PgFieldsLoadedFromJSON(t *testing.T) {
dir := setupTempDir(t)
deleteWal := false
writeConfigJSON(t, dir, Config{
DatabasusHost: "http://json-host:4005",
DbID: "json-db-id",
Token: "json-token",
PgHost: "pg-json-host",
PgPort: 5433,
PgUser: "pg-json-user",
PgPassword: "pg-json-pass",
PgType: "docker",
PgHostBinDir: "/usr/bin",
PgDockerContainerName: "pg-container",
PgWalDir: "/opt/wal",
IsDeleteWalAfterUpload: &deleteWal,
})
cfg := &Config{}
fs := flag.NewFlagSet("test", flag.ContinueOnError)
cfg.LoadFromJSONAndArgs(fs, []string{})
assert.Equal(t, "pg-json-host", cfg.PgHost)
assert.Equal(t, 5433, cfg.PgPort)
assert.Equal(t, "pg-json-user", cfg.PgUser)
assert.Equal(t, "pg-json-pass", cfg.PgPassword)
assert.Equal(t, "docker", cfg.PgType)
assert.Equal(t, "/usr/bin", cfg.PgHostBinDir)
assert.Equal(t, "pg-container", cfg.PgDockerContainerName)
assert.Equal(t, "/opt/wal", cfg.PgWalDir)
assert.Equal(t, false, *cfg.IsDeleteWalAfterUpload)
}
func Test_LoadFromJSONAndArgs_PgFieldsLoadedFromArgs(t *testing.T) {
setupTempDir(t)
cfg := &Config{}
fs := flag.NewFlagSet("test", flag.ContinueOnError)
cfg.LoadFromJSONAndArgs(fs, []string{
"--pg-host", "arg-pg-host",
"--pg-port", "5433",
"--pg-user", "arg-pg-user",
"--pg-password", "arg-pg-pass",
"--pg-type", "docker",
"--pg-host-bin-dir", "/custom/bin",
"--pg-docker-container-name", "my-pg",
"--pg-wal-dir", "/var/wal",
})
assert.Equal(t, "arg-pg-host", cfg.PgHost)
assert.Equal(t, 5433, cfg.PgPort)
assert.Equal(t, "arg-pg-user", cfg.PgUser)
assert.Equal(t, "arg-pg-pass", cfg.PgPassword)
assert.Equal(t, "docker", cfg.PgType)
assert.Equal(t, "/custom/bin", cfg.PgHostBinDir)
assert.Equal(t, "my-pg", cfg.PgDockerContainerName)
assert.Equal(t, "/var/wal", cfg.PgWalDir)
}
func Test_LoadFromJSONAndArgs_PgArgsOverrideJSON(t *testing.T) {
dir := setupTempDir(t)
writeConfigJSON(t, dir, Config{
PgHost: "json-host",
PgPort: 5432,
PgUser: "json-user",
PgType: "host",
PgWalDir: "/json/wal",
})
cfg := &Config{}
fs := flag.NewFlagSet("test", flag.ContinueOnError)
cfg.LoadFromJSONAndArgs(fs, []string{
"--pg-host", "arg-host",
"--pg-port", "5433",
"--pg-user", "arg-user",
"--pg-type", "docker",
"--pg-docker-container-name", "my-container",
"--pg-wal-dir", "/arg/wal",
})
assert.Equal(t, "arg-host", cfg.PgHost)
assert.Equal(t, 5433, cfg.PgPort)
assert.Equal(t, "arg-user", cfg.PgUser)
assert.Equal(t, "docker", cfg.PgType)
assert.Equal(t, "my-container", cfg.PgDockerContainerName)
assert.Equal(t, "/arg/wal", cfg.PgWalDir)
}
func Test_LoadFromJSONAndArgs_DefaultsApplied_WhenNoJSONAndNoArgs(t *testing.T) {
setupTempDir(t)
cfg := &Config{}
fs := flag.NewFlagSet("test", flag.ContinueOnError)
cfg.LoadFromJSONAndArgs(fs, []string{})
assert.Equal(t, 5432, cfg.PgPort)
assert.Equal(t, "host", cfg.PgType)
require.NotNil(t, cfg.IsDeleteWalAfterUpload)
assert.Equal(t, true, *cfg.IsDeleteWalAfterUpload)
}
func Test_SaveToJSON_PgFieldsSavedCorrectly(t *testing.T) {
setupTempDir(t)
deleteWal := false
cfg := &Config{
DatabasusHost: "http://host:4005",
DbID: "db-id",
Token: "token",
PgHost: "pg-host",
PgPort: 5433,
PgUser: "pg-user",
PgPassword: "pg-pass",
PgType: "docker",
PgHostBinDir: "/usr/bin",
PgDockerContainerName: "pg-container",
PgWalDir: "/opt/wal",
IsDeleteWalAfterUpload: &deleteWal,
}
err := cfg.SaveToJSON()
require.NoError(t, err)
saved := readConfigJSON(t)
assert.Equal(t, "pg-host", saved.PgHost)
assert.Equal(t, 5433, saved.PgPort)
assert.Equal(t, "pg-user", saved.PgUser)
assert.Equal(t, "pg-pass", saved.PgPassword)
assert.Equal(t, "docker", saved.PgType)
assert.Equal(t, "/usr/bin", saved.PgHostBinDir)
assert.Equal(t, "pg-container", saved.PgDockerContainerName)
assert.Equal(t, "/opt/wal", saved.PgWalDir)
require.NotNil(t, saved.IsDeleteWalAfterUpload)
assert.Equal(t, false, *saved.IsDeleteWalAfterUpload)
}
func setupTempDir(t *testing.T) string {
t.Helper()

View File

@@ -1,9 +1,17 @@
package config
type parsedFlags struct {
host *string
dbID *string
token *string
databasusHost *string
dbID *string
token *string
pgHost *string
pgPort *int
pgUser *string
pgPassword *string
pgType *string
pgHostBinDir *string
pgDockerContainerName *string
pgWalDir *string
sources map[string]string
}
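
The deleteWalAfterUpload field in the Config diff above is kept as a *bool so that a key missing from databasus.json can be told apart from an explicit false before applyDefaults fills in true. A small standalone illustration of that pattern (not the package's own code):

package main

import (
	"encoding/json"
	"fmt"
)

type settings struct {
	DeleteWalAfterUpload *bool `json:"deleteWalAfterUpload"`
}

func main() {
	inputs := []string{
		`{}`,                              // key absent: default of true should be applied
		`{"deleteWalAfterUpload": false}`, // explicit false must survive defaulting
	}
	for _, raw := range inputs {
		var s settings
		if err := json.Unmarshal([]byte(raw), &s); err != nil {
			panic(err)
		}
		if s.DeleteWalAfterUpload == nil { // mirrors applyDefaults in the diff above
			v := true
			s.DeleteWalAfterUpload = &v
		}
		fmt.Println(raw, "->", *s.DeleteWalAfterUpload)
	}
}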

View File

@@ -0,0 +1,376 @@
package api
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"os"
"time"
"github.com/go-resty/resty/v2"
)
const (
chainValidPath = "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup"
nextBackupTimePath = "/api/v1/backups/postgres/wal/next-full-backup-time"
walUploadPath = "/api/v1/backups/postgres/wal/upload/wal"
fullStartPath = "/api/v1/backups/postgres/wal/upload/full-start"
fullCompletePath = "/api/v1/backups/postgres/wal/upload/full-complete"
reportErrorPath = "/api/v1/backups/postgres/wal/error"
restorePlanPath = "/api/v1/backups/postgres/wal/restore/plan"
restoreDownloadPath = "/api/v1/backups/postgres/wal/restore/download"
versionPath = "/api/v1/system/version"
agentBinaryPath = "/api/v1/system/agent"
apiCallTimeout = 30 * time.Second
maxRetryAttempts = 3
retryBaseDelay = 1 * time.Second
)
// For stream uploads (basebackup and WAL segments) the standard resty client is not used,
// because resty buffers the entire request body in memory before sending.
type Client struct {
json *resty.Client
streamHTTP *http.Client
host string
token string
log *slog.Logger
}
func NewClient(host, token string, log *slog.Logger) *Client {
setAuth := func(_ *resty.Client, req *resty.Request) error {
if token != "" {
req.SetHeader("Authorization", token)
}
return nil
}
jsonClient := resty.New().
SetTimeout(apiCallTimeout).
SetRetryCount(maxRetryAttempts - 1).
SetRetryWaitTime(retryBaseDelay).
SetRetryMaxWaitTime(4 * retryBaseDelay).
AddRetryCondition(func(resp *resty.Response, err error) bool {
return err != nil || resp.StatusCode() >= 500
}).
OnBeforeRequest(setAuth)
return &Client{
json: jsonClient,
streamHTTP: &http.Client{},
host: host,
token: token,
log: log,
}
}
func (c *Client) CheckWalChainValidity(ctx context.Context) (*WalChainValidityResponse, error) {
var resp WalChainValidityResponse
httpResp, err := c.json.R().
SetContext(ctx).
SetResult(&resp).
Get(c.buildURL(chainValidPath))
if err != nil {
return nil, err
}
if err := c.checkResponse(httpResp, "check WAL chain validity"); err != nil {
return nil, err
}
return &resp, nil
}
func (c *Client) GetNextFullBackupTime(ctx context.Context) (*NextFullBackupTimeResponse, error) {
var resp NextFullBackupTimeResponse
httpResp, err := c.json.R().
SetContext(ctx).
SetResult(&resp).
Get(c.buildURL(nextBackupTimePath))
if err != nil {
return nil, err
}
if err := c.checkResponse(httpResp, "get next full backup time"); err != nil {
return nil, err
}
return &resp, nil
}
func (c *Client) ReportBackupError(ctx context.Context, errMsg string) error {
httpResp, err := c.json.R().
SetContext(ctx).
SetBody(reportErrorRequest{Error: errMsg}).
Post(c.buildURL(reportErrorPath))
if err != nil {
return err
}
return c.checkResponse(httpResp, "report backup error")
}
func (c *Client) UploadBasebackup(
ctx context.Context,
body io.Reader,
) (*UploadBasebackupResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL(fullStartPath), body)
if err != nil {
return nil, fmt.Errorf("create upload request: %w", err)
}
c.setStreamHeaders(req)
req.Header.Set("Content-Type", "application/octet-stream")
resp, err := c.streamHTTP.Do(req)
if err != nil {
return nil, fmt.Errorf("upload request: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("upload failed with status %d: %s", resp.StatusCode, string(respBody))
}
var result UploadBasebackupResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("decode upload response: %w", err)
}
return &result, nil
}
func (c *Client) FinalizeBasebackup(
ctx context.Context,
backupID string,
startSegment string,
stopSegment string,
) error {
resp, err := c.json.R().
SetContext(ctx).
SetBody(finalizeBasebackupRequest{
BackupID: backupID,
StartSegment: startSegment,
StopSegment: stopSegment,
}).
Post(c.buildURL(fullCompletePath))
if err != nil {
return fmt.Errorf("finalize request: %w", err)
}
if resp.StatusCode() != http.StatusOK {
return fmt.Errorf("finalize failed with status %d: %s", resp.StatusCode(), resp.String())
}
return nil
}
func (c *Client) FinalizeBasebackupWithError(
ctx context.Context,
backupID string,
errMsg string,
) error {
resp, err := c.json.R().
SetContext(ctx).
SetBody(finalizeBasebackupRequest{
BackupID: backupID,
Error: &errMsg,
}).
Post(c.buildURL(fullCompletePath))
if err != nil {
return fmt.Errorf("finalize-with-error request: %w", err)
}
if resp.StatusCode() != http.StatusOK {
return fmt.Errorf("finalize-with-error failed with status %d: %s", resp.StatusCode(), resp.String())
}
return nil
}
func (c *Client) UploadWalSegment(
ctx context.Context,
segmentName string,
body io.Reader,
) (*UploadWalSegmentResult, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL(walUploadPath), body)
if err != nil {
return nil, fmt.Errorf("create WAL upload request: %w", err)
}
c.setStreamHeaders(req)
req.Header.Set("Content-Type", "application/octet-stream")
req.Header.Set("X-Wal-Segment-Name", segmentName)
resp, err := c.streamHTTP.Do(req)
if err != nil {
return nil, fmt.Errorf("upload request: %w", err)
}
defer func() { _ = resp.Body.Close() }()
switch resp.StatusCode {
case http.StatusNoContent:
return &UploadWalSegmentResult{IsGapDetected: false}, nil
case http.StatusConflict:
var errResp uploadErrorResponse
if err := json.NewDecoder(resp.Body).Decode(&errResp); err != nil {
return &UploadWalSegmentResult{IsGapDetected: true}, nil
}
return &UploadWalSegmentResult{
IsGapDetected: true,
ExpectedSegmentName: errResp.ExpectedSegmentName,
ReceivedSegmentName: errResp.ReceivedSegmentName,
}, nil
default:
respBody, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("upload failed with status %d: %s", resp.StatusCode, string(respBody))
}
}
func (c *Client) GetRestorePlan(
ctx context.Context,
backupID string,
) (*GetRestorePlanResponse, *GetRestorePlanErrorResponse, error) {
request := c.json.R().SetContext(ctx)
if backupID != "" {
request.SetQueryParam("backupId", backupID)
}
httpResp, err := request.Get(c.buildURL(restorePlanPath))
if err != nil {
return nil, nil, fmt.Errorf("get restore plan: %w", err)
}
switch httpResp.StatusCode() {
case http.StatusOK:
var response GetRestorePlanResponse
if err := json.Unmarshal(httpResp.Body(), &response); err != nil {
return nil, nil, fmt.Errorf("decode restore plan response: %w", err)
}
return &response, nil, nil
case http.StatusBadRequest:
var errorResponse GetRestorePlanErrorResponse
if err := json.Unmarshal(httpResp.Body(), &errorResponse); err != nil {
return nil, nil, fmt.Errorf("decode restore plan error: %w", err)
}
return nil, &errorResponse, nil
default:
return nil, nil, fmt.Errorf("get restore plan: server returned status %d: %s",
httpResp.StatusCode(), httpResp.String())
}
}
func (c *Client) DownloadBackupFile(
ctx context.Context,
backupID string,
) (io.ReadCloser, error) {
requestURL := c.buildURL(restoreDownloadPath) + "?" + url.Values{"backupId": {backupID}}.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return nil, fmt.Errorf("create download request: %w", err)
}
c.setStreamHeaders(req)
resp, err := c.streamHTTP.Do(req)
if err != nil {
return nil, fmt.Errorf("download backup file: %w", err)
}
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
_ = resp.Body.Close()
return nil, fmt.Errorf("download backup file: server returned status %d: %s",
resp.StatusCode, string(respBody))
}
return resp.Body, nil
}
func (c *Client) FetchServerVersion(ctx context.Context) (string, error) {
var ver versionResponse
httpResp, err := c.json.R().
SetContext(ctx).
SetResult(&ver).
Get(c.buildURL(versionPath))
if err != nil {
return "", err
}
if err := c.checkResponse(httpResp, "fetch server version"); err != nil {
return "", err
}
return ver.Version, nil
}
func (c *Client) DownloadAgentBinary(ctx context.Context, arch, destPath string) error {
requestURL := c.buildURL(agentBinaryPath) + "?" + url.Values{"arch": {arch}}.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return fmt.Errorf("create agent download request: %w", err)
}
c.setStreamHeaders(req)
resp, err := c.streamHTTP.Do(req)
if err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("server returned %d for agent download", resp.StatusCode)
}
file, err := os.Create(destPath)
if err != nil {
return err
}
defer func() { _ = file.Close() }()
_, err = io.Copy(file, resp.Body)
return err
}
func (c *Client) buildURL(path string) string {
return c.host + path
}
func (c *Client) checkResponse(resp *resty.Response, method string) error {
if resp.StatusCode() >= 400 {
return fmt.Errorf("%s: server returned status %d: %s", method, resp.StatusCode(), resp.String())
}
return nil
}
func (c *Client) setStreamHeaders(req *http.Request) {
if c.token != "" {
req.Header.Set("Authorization", c.token)
}
}
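
From a caller's point of view the split between the retrying resty client (JSON endpoints) and the bare http.Client (streaming uploads) is invisible. A minimal usage sketch of the exported surface; the host, token, and segment path are placeholders:

// Hypothetical caller of the api.Client above; values are placeholders.
package main

import (
	"context"
	"log/slog"
	"os"

	"databasus-agent/internal/features/api"
)

func main() {
	log := slog.Default()
	client := api.NewClient("http://databasus.example:4005", "agent-token", log)
	ctx := context.Background()

	version, err := client.FetchServerVersion(ctx)
	if err != nil {
		log.Error("fetch server version", "error", err)
		return
	}
	log.Info("server version", "version", version)

	// Stream a WAL segment from disk; the request body is never buffered in memory.
	segment, err := os.Open("/tmp/wal-queue/000000010000000000000002")
	if err != nil {
		log.Error("open segment", "error", err)
		return
	}
	defer segment.Close()

	result, err := client.UploadWalSegment(ctx, "000000010000000000000002", segment)
	if err != nil {
		log.Error("upload WAL segment", "error", err)
		return
	}
	if result.IsGapDetected {
		log.Warn("WAL gap detected", "expected", result.ExpectedSegmentName, "received", result.ReceivedSegmentName)
	}
}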

View File

@@ -0,0 +1,72 @@
package api
import "time"
type WalChainValidityResponse struct {
IsValid bool `json:"isValid"`
Error string `json:"error,omitempty"`
LastContiguousSegment string `json:"lastContiguousSegment,omitempty"`
}
type NextFullBackupTimeResponse struct {
NextFullBackupTime *time.Time `json:"nextFullBackupTime"`
}
type UploadWalSegmentResult struct {
IsGapDetected bool
ExpectedSegmentName string
ReceivedSegmentName string
}
type reportErrorRequest struct {
Error string `json:"error"`
}
type versionResponse struct {
Version string `json:"version"`
}
type UploadBasebackupResponse struct {
BackupID string `json:"backupId"`
}
type finalizeBasebackupRequest struct {
BackupID string `json:"backupId"`
StartSegment string `json:"startSegment"`
StopSegment string `json:"stopSegment"`
Error *string `json:"error,omitempty"`
}
type uploadErrorResponse struct {
Error string `json:"error"`
ExpectedSegmentName string `json:"expectedSegmentName"`
ReceivedSegmentName string `json:"receivedSegmentName"`
}
type RestorePlanFullBackup struct {
BackupID string `json:"id"`
FullBackupWalStartSegment string `json:"fullBackupWalStartSegment"`
FullBackupWalStopSegment string `json:"fullBackupWalStopSegment"`
PgVersion string `json:"pgVersion"`
CreatedAt time.Time `json:"createdAt"`
SizeBytes int64 `json:"sizeBytes"`
}
type RestorePlanWalSegment struct {
BackupID string `json:"backupId"`
SegmentName string `json:"segmentName"`
SizeBytes int64 `json:"sizeBytes"`
}
type GetRestorePlanResponse struct {
FullBackup RestorePlanFullBackup `json:"fullBackup"`
WalSegments []RestorePlanWalSegment `json:"walSegments"`
TotalSizeBytes int64 `json:"totalSizeBytes"`
LatestAvailableSegment string `json:"latestAvailableSegment"`
}
type GetRestorePlanErrorResponse struct {
Error string `json:"error"`
Message string `json:"message"`
LastContiguousSegment string `json:"lastContiguousSegment,omitempty"`
}
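
How these restore types are consumed is not part of this diff; one plausible restore driver, built only from GetRestorePlan and DownloadBackupFile above, is sketched below. Downloading WAL segments by their backupId and writing each artifact to a file named after that id are assumptions made for illustration:

// Hypothetical restore driver; the agent's actual restore flow is not shown in this diff.
package restore_sketch

import (
	"context"
	"fmt"
	"io"
	"os"
	"path/filepath"

	"databasus-agent/internal/features/api"
)

func downloadRestoreArtifacts(ctx context.Context, client *api.Client, targetDir string) error {
	plan, planErr, err := client.GetRestorePlan(ctx, "") // empty backupId: assumed to mean "latest"
	if err != nil {
		return err
	}
	if planErr != nil {
		return fmt.Errorf("restore plan rejected: %s (%s)", planErr.Error, planErr.Message)
	}
	// The full basebackup first, then every WAL segment listed in the plan.
	ids := []string{plan.FullBackup.BackupID}
	for _, segment := range plan.WalSegments {
		ids = append(ids, segment.BackupID)
	}
	for _, id := range ids {
		if err := downloadOne(ctx, client, id, filepath.Join(targetDir, id)); err != nil {
			return err
		}
	}
	return nil
}

func downloadOne(ctx context.Context, client *api.Client, backupID, destPath string) error {
	body, err := client.DownloadBackupFile(ctx, backupID)
	if err != nil {
		return err
	}
	defer body.Close()
	out, err := os.Create(destPath)
	if err != nil {
		return err
	}
	defer out.Close()
	_, err = io.Copy(out, body)
	return err
}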

View File

@@ -0,0 +1,60 @@
package api
import (
"context"
"fmt"
"io"
"time"
)
// IdleTimeoutReader wraps an io.Reader and cancels the associated context
// if no bytes are successfully read within the specified timeout duration.
// This detects stalled uploads where the network or source stops transmitting data.
//
// When the idle timeout fires, the wrapped reader is also closed (if it implements io.Closer)
// to unblock any goroutine blocked on the underlying Read.
type IdleTimeoutReader struct {
reader io.Reader
timeout time.Duration
cancel context.CancelCauseFunc
timer *time.Timer
}
// NewIdleTimeoutReader creates a reader that cancels the context via cancel
// if Read does not return any bytes for the given timeout duration.
func NewIdleTimeoutReader(reader io.Reader, timeout time.Duration, cancel context.CancelCauseFunc) *IdleTimeoutReader {
r := &IdleTimeoutReader{
reader: reader,
timeout: timeout,
cancel: cancel,
}
r.timer = time.AfterFunc(timeout, func() {
cancel(fmt.Errorf("upload idle timeout: no bytes transmitted for %v", timeout))
if closer, ok := reader.(io.Closer); ok {
_ = closer.Close()
}
})
return r
}
func (r *IdleTimeoutReader) Read(p []byte) (int, error) {
n, err := r.reader.Read(p)
if n > 0 {
r.timer.Reset(r.timeout)
}
if err != nil && err != io.EOF {
r.Stop()
}
return n, err
}
// Stop cancels the idle timer. Must be called when the reader is no longer needed.
func (r *IdleTimeoutReader) Stop() {
r.timer.Stop()
}
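
Typical wiring for this reader, mirroring how FullBackuper uses it further down in this diff: derive a cancellable-with-cause context, wrap the upload body, and stop the timer once the request returns. The URL and the timeout below are placeholders:

// Usage sketch for IdleTimeoutReader; not taken verbatim from the agent.
package upload_sketch

import (
	"context"
	"net/http"
	"os"
	"time"

	"databasus-agent/internal/features/api"
)

func uploadWithIdleTimeout(ctx context.Context, url, path string) error {
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()

	// If no bytes are read for 5 minutes, the context is cancelled and the file is closed,
	// which unblocks the transport goroutine reading the request body.
	idleCtx, cancel := context.WithCancelCause(ctx)
	defer cancel(nil)
	reader := api.NewIdleTimeoutReader(f, 5*time.Minute, cancel)
	defer reader.Stop()

	req, err := http.NewRequestWithContext(idleCtx, http.MethodPost, url, reader)
	if err != nil {
		return err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		// Surface the idle-timeout cause when that is what aborted the upload.
		if cause := context.Cause(idleCtx); cause != nil {
			return cause
		}
		return err
	}
	defer resp.Body.Close()
	return nil
}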

View File

@@ -0,0 +1,112 @@
package api
import (
"context"
"fmt"
"io"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_ReadThroughIdleTimeoutReader_WhenBytesFlowContinuously_DoesNotCancelContext(t *testing.T) {
ctx, cancel := context.WithCancelCause(t.Context())
defer cancel(nil)
pr, pw := io.Pipe()
idleReader := NewIdleTimeoutReader(pr, 200*time.Millisecond, cancel)
defer idleReader.Stop()
go func() {
for range 5 {
_, _ = pw.Write([]byte("data"))
time.Sleep(50 * time.Millisecond)
}
_ = pw.Close()
}()
data, err := io.ReadAll(idleReader)
require.NoError(t, err)
assert.Equal(t, "datadatadatadatadata", string(data))
assert.NoError(t, ctx.Err(), "context should not be cancelled when bytes flow continuously")
}
func Test_ReadThroughIdleTimeoutReader_WhenNoBytesTransmitted_CancelsContext(t *testing.T) {
ctx, cancel := context.WithCancelCause(t.Context())
defer cancel(nil)
pr, _ := io.Pipe()
idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel)
defer idleReader.Stop()
time.Sleep(200 * time.Millisecond)
assert.Error(t, ctx.Err(), "context should be cancelled when no bytes are transmitted")
assert.Contains(t, context.Cause(ctx).Error(), "upload idle timeout")
}
func Test_ReadThroughIdleTimeoutReader_WhenBytesStopMidStream_CancelsContext(t *testing.T) {
ctx, cancel := context.WithCancelCause(t.Context())
defer cancel(nil)
pr, pw := io.Pipe()
idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel)
defer idleReader.Stop()
go func() {
_, _ = pw.Write([]byte("initial"))
// Stop writing — simulate stalled source
}()
buf := make([]byte, 1024)
n, _ := idleReader.Read(buf)
assert.Equal(t, "initial", string(buf[:n]))
time.Sleep(200 * time.Millisecond)
assert.Error(t, ctx.Err(), "context should be cancelled when bytes stop mid-stream")
assert.Contains(t, context.Cause(ctx).Error(), "upload idle timeout")
}
func Test_StopIdleTimeoutReader_WhenCalledBeforeTimeout_DoesNotCancelContext(t *testing.T) {
ctx, cancel := context.WithCancelCause(t.Context())
defer cancel(nil)
pr, _ := io.Pipe()
idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel)
idleReader.Stop()
time.Sleep(200 * time.Millisecond)
assert.NoError(t, ctx.Err(), "context should not be cancelled when reader is stopped before timeout")
}
func Test_ReadThroughIdleTimeoutReader_WhenReaderReturnsError_PropagatesError(t *testing.T) {
ctx, cancel := context.WithCancelCause(t.Context())
defer cancel(nil)
pr, pw := io.Pipe()
idleReader := NewIdleTimeoutReader(pr, 5*time.Second, cancel)
defer idleReader.Stop()
expectedErr := fmt.Errorf("test read error")
_ = pw.CloseWithError(expectedErr)
buf := make([]byte, 1024)
_, err := idleReader.Read(buf)
assert.ErrorIs(t, err, expectedErr)
// Timer should be stopped after error — context should not be cancelled
time.Sleep(100 * time.Millisecond)
assert.NoError(t, ctx.Err(), "context should not be cancelled after reader error stops the timer")
}

View File

@@ -0,0 +1,316 @@
package full_backup
import (
"bytes"
"context"
"fmt"
"io"
"log/slog"
"os"
"os/exec"
"path/filepath"
"sync/atomic"
"time"
"github.com/klauspost/compress/zstd"
"databasus-agent/internal/config"
"databasus-agent/internal/features/api"
)
const (
checkInterval = 30 * time.Second
retryDelay = 1 * time.Minute
uploadTimeout = 23 * time.Hour
)
var uploadIdleTimeout = 5 * time.Minute
var retryDelayOverride *time.Duration
type CmdBuilder func(ctx context.Context) *exec.Cmd
// FullBackuper runs pg_basebackup when the WAL chain is broken or a scheduled backup is due.
//
// Every 30 seconds it checks two conditions via the Databasus API:
// 1. WAL chain validity — if broken or no full backup exists, triggers an immediate basebackup.
// 2. Scheduled backup time — if the next full backup time has passed, triggers a basebackup.
//
// Only one basebackup runs at a time (guarded by atomic bool).
// On failure the error is reported to the server and the backup retries after 1 minute, indefinitely.
// WAL segment uploads (handled by wal.Streamer) continue independently and are not paused.
//
// pg_basebackup runs as "pg_basebackup -Ft -D - -X fetch --verbose --checkpoint=fast".
// Stdout (tar) is zstd-compressed and uploaded to the server.
// Stderr is parsed for WAL start/stop segment names (LSN → segment arithmetic).
type FullBackuper struct {
cfg *config.Config
apiClient *api.Client
log *slog.Logger
isRunning atomic.Bool
cmdBuilder CmdBuilder
}
func NewFullBackuper(cfg *config.Config, apiClient *api.Client, log *slog.Logger) *FullBackuper {
backuper := &FullBackuper{
cfg: cfg,
apiClient: apiClient,
log: log,
}
backuper.cmdBuilder = backuper.defaultCmdBuilder
return backuper
}
func (backuper *FullBackuper) Run(ctx context.Context) {
backuper.log.Info("Full backuper started")
backuper.checkAndRunIfNeeded(ctx)
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
backuper.log.Info("Full backuper stopping")
return
case <-ticker.C:
backuper.checkAndRunIfNeeded(ctx)
}
}
}
func (backuper *FullBackuper) checkAndRunIfNeeded(ctx context.Context) {
if backuper.isRunning.Load() {
backuper.log.Debug("Skipping check: basebackup already in progress")
return
}
chainResp, err := backuper.apiClient.CheckWalChainValidity(ctx)
if err != nil {
backuper.log.Error("Failed to check WAL chain validity", "error", err)
return
}
if !chainResp.IsValid {
backuper.log.Info("WAL chain is invalid, triggering basebackup",
"error", chainResp.Error,
"lastContiguousSegment", chainResp.LastContiguousSegment,
)
backuper.runBasebackupWithRetry(ctx)
return
}
nextTimeResp, err := backuper.apiClient.GetNextFullBackupTime(ctx)
if err != nil {
backuper.log.Error("Failed to check next full backup time", "error", err)
return
}
if nextTimeResp.NextFullBackupTime == nil || !nextTimeResp.NextFullBackupTime.After(time.Now().UTC()) {
backuper.log.Info("Scheduled full backup is due, triggering basebackup")
backuper.runBasebackupWithRetry(ctx)
return
}
backuper.log.Debug("No basebackup needed",
"nextFullBackupTime", nextTimeResp.NextFullBackupTime,
)
}
func (backuper *FullBackuper) runBasebackupWithRetry(ctx context.Context) {
if !backuper.isRunning.CompareAndSwap(false, true) {
backuper.log.Debug("Skipping basebackup: already running")
return
}
defer backuper.isRunning.Store(false)
for {
if ctx.Err() != nil {
return
}
backuper.log.Info("Starting pg_basebackup")
err := backuper.executeAndUploadBasebackup(ctx)
if err == nil {
backuper.log.Info("Basebackup completed successfully")
return
}
backuper.log.Error("Basebackup failed", "error", err)
backuper.reportError(ctx, err.Error())
delay := retryDelay
if retryDelayOverride != nil {
delay = *retryDelayOverride
}
backuper.log.Info("Retrying basebackup after delay", "delay", delay)
select {
case <-ctx.Done():
return
case <-time.After(delay):
}
}
}
func (backuper *FullBackuper) executeAndUploadBasebackup(ctx context.Context) error {
cmd := backuper.cmdBuilder(ctx)
var stderrBuf bytes.Buffer
cmd.Stderr = &stderrBuf
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("create stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("start pg_basebackup: %w", err)
}
// Phase 1: Stream compressed data via io.Pipe directly to the API.
pipeReader, pipeWriter := io.Pipe()
defer func() { _ = pipeReader.Close() }()
go backuper.compressAndStream(pipeWriter, stdoutPipe)
uploadCtx, timeoutCancel := context.WithTimeout(ctx, uploadTimeout)
defer timeoutCancel()
idleCtx, idleCancel := context.WithCancelCause(uploadCtx)
defer idleCancel(nil)
idleReader := api.NewIdleTimeoutReader(pipeReader, uploadIdleTimeout, idleCancel)
defer idleReader.Stop()
uploadResp, uploadErr := backuper.apiClient.UploadBasebackup(idleCtx, idleReader)
if uploadErr != nil && cmd.Process != nil {
_ = cmd.Process.Kill()
}
cmdErr := cmd.Wait()
if uploadErr != nil {
if cause := context.Cause(idleCtx); cause != nil {
uploadErr = cause
}
stderrStr := stderrBuf.String()
if stderrStr != "" {
return fmt.Errorf("upload basebackup: %w (pg_basebackup stderr: %s)", uploadErr, stderrStr)
}
return fmt.Errorf("upload basebackup: %w", uploadErr)
}
if cmdErr != nil {
errMsg := fmt.Sprintf("pg_basebackup exited with error: %v (stderr: %s)", cmdErr, stderrBuf.String())
_ = backuper.apiClient.FinalizeBasebackupWithError(ctx, uploadResp.BackupID, errMsg)
return fmt.Errorf("%s", errMsg)
}
// Phase 2: Parse stderr for WAL segments and finalize the backup.
stderrStr := stderrBuf.String()
backuper.log.Debug("pg_basebackup stderr", "stderr", stderrStr)
startSegment, stopSegment, err := ParseBasebackupStderr(stderrStr)
if err != nil {
errMsg := fmt.Sprintf("parse pg_basebackup stderr: %v", err)
_ = backuper.apiClient.FinalizeBasebackupWithError(ctx, uploadResp.BackupID, errMsg)
return fmt.Errorf("parse pg_basebackup stderr: %w", err)
}
backuper.log.Info("Basebackup WAL segments parsed",
"startSegment", startSegment,
"stopSegment", stopSegment,
"backupId", uploadResp.BackupID,
)
if err := backuper.apiClient.FinalizeBasebackup(ctx, uploadResp.BackupID, startSegment, stopSegment); err != nil {
return fmt.Errorf("finalize basebackup: %w", err)
}
return nil
}
func (backuper *FullBackuper) compressAndStream(pipeWriter *io.PipeWriter, reader io.Reader) {
encoder, err := zstd.NewWriter(pipeWriter,
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(5)),
zstd.WithEncoderCRC(true),
)
if err != nil {
_ = pipeWriter.CloseWithError(fmt.Errorf("create zstd encoder: %w", err))
return
}
if _, err := io.Copy(encoder, reader); err != nil {
_ = encoder.Close()
_ = pipeWriter.CloseWithError(fmt.Errorf("compress: %w", err))
return
}
if err := encoder.Close(); err != nil {
_ = pipeWriter.CloseWithError(fmt.Errorf("close encoder: %w", err))
return
}
_ = pipeWriter.Close()
}
func (backuper *FullBackuper) reportError(ctx context.Context, errMsg string) {
if err := backuper.apiClient.ReportBackupError(ctx, errMsg); err != nil {
backuper.log.Error("Failed to report error to server", "error", err)
}
}
func (backuper *FullBackuper) defaultCmdBuilder(ctx context.Context) *exec.Cmd {
switch backuper.cfg.PgType {
case "docker":
return backuper.buildDockerCmd(ctx)
default:
return backuper.buildHostCmd(ctx)
}
}
func (backuper *FullBackuper) buildHostCmd(ctx context.Context) *exec.Cmd {
binary := "pg_basebackup"
if backuper.cfg.PgHostBinDir != "" {
binary = filepath.Join(backuper.cfg.PgHostBinDir, "pg_basebackup")
}
cmd := exec.CommandContext(ctx, binary,
"-Ft", "-D", "-", "-X", "fetch", "--verbose", "--checkpoint=fast",
"-h", backuper.cfg.PgHost,
"-p", fmt.Sprintf("%d", backuper.cfg.PgPort),
"-U", backuper.cfg.PgUser,
)
cmd.Env = append(os.Environ(), "PGPASSWORD="+backuper.cfg.PgPassword)
return cmd
}
func (backuper *FullBackuper) buildDockerCmd(ctx context.Context) *exec.Cmd {
cmd := exec.CommandContext(ctx, "docker", "exec",
"-e", "PGPASSWORD="+backuper.cfg.PgPassword,
"-i", backuper.cfg.PgDockerContainerName,
"pg_basebackup",
"-Ft", "-D", "-", "-X", "fetch", "--verbose", "--checkpoint=fast",
"-h", "localhost",
"-p", "5432",
"-U", backuper.cfg.PgUser,
)
return cmd
}
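// For reference, the two builders above produce commands equivalent to the
// following (angle-bracket placeholders stand for config values, not literals):
//
//   host:   PGPASSWORD=... pg_basebackup -Ft -D - -X fetch --verbose \
//             --checkpoint=fast -h <PgHost> -p <PgPort> -U <PgUser>
//   docker: docker exec -e PGPASSWORD=... -i <PgDockerContainerName> pg_basebackup \
//             -Ft -D - -X fetch --verbose --checkpoint=fast -h localhost -p 5432 -U <PgUser>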

View File

@@ -0,0 +1,735 @@
package full_backup
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/klauspost/compress/zstd"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"databasus-agent/internal/config"
"databasus-agent/internal/features/api"
"databasus-agent/internal/logger"
)
const (
testChainValidPath = "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup"
testNextBackupTimePath = "/api/v1/backups/postgres/wal/next-full-backup-time"
testFullStartPath = "/api/v1/backups/postgres/wal/upload/full-start"
testFullCompletePath = "/api/v1/backups/postgres/wal/upload/full-complete"
testReportErrorPath = "/api/v1/backups/postgres/wal/error"
testBackupID = "test-backup-id-1234"
)
func Test_RunFullBackup_WhenChainBroken_BasebackupTriggered(t *testing.T) {
var mu sync.Mutex
var uploadReceived bool
var uploadHeaders http.Header
var finalizeReceived bool
var finalizeBody map[string]any
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{
IsValid: false,
Error: "wal_chain_broken",
LastContiguousSegment: "000000010000000100000011",
})
case testFullStartPath:
mu.Lock()
uploadReceived = true
uploadHeaders = r.Header.Clone()
mu.Unlock()
_, _ = io.ReadAll(r.Body)
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
mu.Lock()
finalizeReceived = true
_ = json.NewDecoder(r.Body).Decode(&finalizeBody)
mu.Unlock()
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "test-backup-data", validStderr())
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
defer cancel()
go fb.Run(ctx)
waitForCondition(t, func() bool {
mu.Lock()
defer mu.Unlock()
return finalizeReceived
}, 5*time.Second)
cancel()
mu.Lock()
defer mu.Unlock()
assert.True(t, uploadReceived)
assert.Equal(t, "application/octet-stream", uploadHeaders.Get("Content-Type"))
assert.Equal(t, "test-token", uploadHeaders.Get("Authorization"))
assert.True(t, finalizeReceived)
assert.Equal(t, testBackupID, finalizeBody["backupId"])
assert.Equal(t, "000000010000000000000002", finalizeBody["startSegment"])
assert.Equal(t, "000000010000000000000002", finalizeBody["stopSegment"])
}
func Test_RunFullBackup_WhenScheduledBackupDue_BasebackupTriggered(t *testing.T) {
var mu sync.Mutex
var finalizeReceived bool
pastTime := time.Now().UTC().Add(-1 * time.Hour)
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{IsValid: true})
case testNextBackupTimePath:
writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: &pastTime})
case testFullStartPath:
_, _ = io.ReadAll(r.Body)
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
mu.Lock()
finalizeReceived = true
mu.Unlock()
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "scheduled-backup-data", validStderr())
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
defer cancel()
go fb.Run(ctx)
waitForCondition(t, func() bool {
mu.Lock()
defer mu.Unlock()
return finalizeReceived
}, 5*time.Second)
cancel()
mu.Lock()
defer mu.Unlock()
assert.True(t, finalizeReceived)
}
func Test_RunFullBackup_WhenNoFullBackupExists_ImmediateBasebackupTriggered(t *testing.T) {
var mu sync.Mutex
var finalizeReceived bool
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{
IsValid: false,
Error: "no_full_backup",
})
case testFullStartPath:
_, _ = io.ReadAll(r.Body)
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
mu.Lock()
finalizeReceived = true
mu.Unlock()
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "first-backup-data", validStderr())
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
defer cancel()
go fb.Run(ctx)
waitForCondition(t, func() bool {
mu.Lock()
defer mu.Unlock()
return finalizeReceived
}, 5*time.Second)
cancel()
mu.Lock()
defer mu.Unlock()
assert.True(t, finalizeReceived)
}
func Test_RunFullBackup_WhenUploadFails_RetriesAfterDelay(t *testing.T) {
var mu sync.Mutex
var uploadAttempts int
var errorReported bool
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{
IsValid: false,
Error: "no_full_backup",
})
case testFullStartPath:
_, _ = io.ReadAll(r.Body)
mu.Lock()
uploadAttempts++
attempt := uploadAttempts
mu.Unlock()
if attempt == 1 {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(`{"error":"storage unavailable"}`))
return
}
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
w.WriteHeader(http.StatusOK)
case testReportErrorPath:
mu.Lock()
errorReported = true
mu.Unlock()
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "retry-backup-data", validStderr())
origRetryDelay := retryDelay
setRetryDelay(100 * time.Millisecond)
defer setRetryDelay(origRetryDelay)
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
defer cancel()
go fb.Run(ctx)
waitForCondition(t, func() bool {
mu.Lock()
defer mu.Unlock()
return uploadAttempts >= 2
}, 10*time.Second)
cancel()
mu.Lock()
defer mu.Unlock()
assert.GreaterOrEqual(t, uploadAttempts, 2)
assert.True(t, errorReported)
}
func Test_RunFullBackup_WhenAlreadyRunning_SkipsExecution(t *testing.T) {
var mu sync.Mutex
var uploadCount int
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{
IsValid: false,
Error: "no_full_backup",
})
case testFullStartPath:
_, _ = io.ReadAll(r.Body)
mu.Lock()
uploadCount++
mu.Unlock()
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
fb.isRunning.Store(true)
fb.checkAndRunIfNeeded(t.Context())
mu.Lock()
count := uploadCount
mu.Unlock()
assert.Equal(t, 0, count, "should not trigger backup when already running")
}
func Test_RunFullBackup_WhenContextCancelled_StopsCleanly(t *testing.T) {
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{
IsValid: false,
Error: "no_full_backup",
})
case testFullStartPath:
_, _ = io.ReadAll(r.Body)
w.WriteHeader(http.StatusInternalServerError)
case testFullCompletePath:
w.WriteHeader(http.StatusOK)
case testReportErrorPath:
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
origRetryDelay := retryDelay
setRetryDelay(5 * time.Second)
defer setRetryDelay(origRetryDelay)
ctx, cancel := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer cancel()
done := make(chan struct{})
go func() {
fb.Run(ctx)
close(done)
}()
select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatal("Run should have stopped after context cancellation")
}
}
func Test_RunFullBackup_WhenChainValidAndNotScheduled_NoBasebackupTriggered(t *testing.T) {
var uploadReceived atomic.Bool
futureTime := time.Now().UTC().Add(24 * time.Hour)
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{IsValid: true})
case testNextBackupTimePath:
writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: &futureTime})
case testFullStartPath:
uploadReceived.Store(true)
_, _ = io.ReadAll(r.Body)
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
defer cancel()
go fb.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
assert.False(t, uploadReceived.Load(), "should not trigger backup when chain valid and not scheduled")
}
func Test_RunFullBackup_WhenStderrParsingFails_FinalizesWithErrorAndRetries(t *testing.T) {
var mu sync.Mutex
var errorReported bool
var finalizeWithErrorReceived bool
var finalizeBody map[string]any
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{
IsValid: false,
Error: "no_full_backup",
})
case testFullStartPath:
_, _ = io.ReadAll(r.Body)
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
mu.Lock()
finalizeWithErrorReceived = true
_ = json.NewDecoder(r.Body).Decode(&finalizeBody)
mu.Unlock()
w.WriteHeader(http.StatusOK)
case testReportErrorPath:
mu.Lock()
errorReported = true
mu.Unlock()
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "data", "pg_basebackup: unexpected output with no LSN info")
origRetryDelay := retryDelay
setRetryDelay(100 * time.Millisecond)
defer setRetryDelay(origRetryDelay)
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
defer cancel()
go fb.Run(ctx)
waitForCondition(t, func() bool {
mu.Lock()
defer mu.Unlock()
return errorReported
}, 2*time.Second)
cancel()
mu.Lock()
defer mu.Unlock()
assert.True(t, errorReported)
assert.True(t, finalizeWithErrorReceived, "should finalize with error when stderr parsing fails")
assert.Equal(t, testBackupID, finalizeBody["backupId"])
assert.NotNil(t, finalizeBody["error"], "finalize should include error message")
}
func Test_RunFullBackup_WhenNextBackupTimeNull_BasebackupTriggered(t *testing.T) {
var mu sync.Mutex
var finalizeReceived bool
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{IsValid: true})
case testNextBackupTimePath:
writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: nil})
case testFullStartPath:
_, _ = io.ReadAll(r.Body)
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
mu.Lock()
finalizeReceived = true
mu.Unlock()
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "first-run-data", validStderr())
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
defer cancel()
go fb.Run(ctx)
waitForCondition(t, func() bool {
mu.Lock()
defer mu.Unlock()
return finalizeReceived
}, 5*time.Second)
cancel()
mu.Lock()
defer mu.Unlock()
assert.True(t, finalizeReceived)
}
func Test_RunFullBackup_WhenChainValidityReturns401_NoBasebackupTriggered(t *testing.T) {
var uploadReceived atomic.Bool
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"invalid token"}`))
case testFullStartPath:
uploadReceived.Store(true)
_, _ = io.ReadAll(r.Body)
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
defer cancel()
go fb.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
assert.False(t, uploadReceived.Load(), "should not trigger backup when API returns 401")
}
func Test_RunFullBackup_WhenUploadSucceeds_BodyIsZstdCompressed(t *testing.T) {
var mu sync.Mutex
var receivedBody []byte
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{
IsValid: false,
Error: "no_full_backup",
})
case testFullStartPath:
body, _ := io.ReadAll(r.Body)
mu.Lock()
receivedBody = body
mu.Unlock()
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
originalContent := "test-backup-content-for-compression-check"
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, originalContent, validStderr())
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
defer cancel()
go fb.Run(ctx)
waitForCondition(t, func() bool {
mu.Lock()
defer mu.Unlock()
return len(receivedBody) > 0
}, 5*time.Second)
cancel()
mu.Lock()
body := receivedBody
mu.Unlock()
decoder, err := zstd.NewReader(nil)
require.NoError(t, err)
defer decoder.Close()
decompressed, err := decoder.DecodeAll(body, nil)
require.NoError(t, err)
assert.Equal(t, originalContent, string(decompressed))
}
func Test_RunFullBackup_WhenUploadStalls_FailsWithIdleTimeout(t *testing.T) {
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testFullStartPath:
// The server reads the body as usual; io.ReadAll blocks until the client closes or aborts the connection
_, _ = io.ReadAll(r.Body)
writeJSON(w, map[string]string{"backupId": testBackupID})
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = stallingCmdBuilder(t)
origIdleTimeout := uploadIdleTimeout
uploadIdleTimeout = 200 * time.Millisecond
defer func() { uploadIdleTimeout = origIdleTimeout }()
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
defer cancel()
err := fb.executeAndUploadBasebackup(ctx)
assert.Error(t, err)
assert.Contains(t, err.Error(), "idle timeout", "error should mention idle timeout")
}
func stallingCmdBuilder(t *testing.T) CmdBuilder {
t.Helper()
return func(ctx context.Context) *exec.Cmd {
cmd := exec.CommandContext(ctx, os.Args[0],
"-test.run=TestHelperProcessStalling",
"--",
)
cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS_STALLING=1")
return cmd
}
}
func TestHelperProcessStalling(t *testing.T) {
if os.Getenv("GO_TEST_HELPER_PROCESS_STALLING") != "1" {
return
}
// Write enough data to flush through the zstd encoder's internal buffer (~128KB blocks).
// Without enough data, zstd buffers everything and the pipe never receives bytes.
data := make([]byte, 256*1024)
for i := range data {
data[i] = byte(i)
}
_, _ = os.Stdout.Write(data)
// Stall with stdout open — the compress goroutine blocks on its next read.
// The parent process will kill us when the context is cancelled.
time.Sleep(time.Hour)
os.Exit(0)
}
func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
t.Helper()
server := httptest.NewServer(handler)
t.Cleanup(server.Close)
return server
}
func newTestFullBackuper(serverURL string) *FullBackuper {
cfg := &config.Config{
DatabasusHost: serverURL,
DbID: "test-db-id",
Token: "test-token",
PgHost: "localhost",
PgPort: 5432,
PgUser: "postgres",
PgPassword: "password",
PgType: "host",
}
apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger())
return NewFullBackuper(cfg, apiClient, logger.GetLogger())
}
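// mockCmdBuilder swaps the real pg_basebackup invocation for a re-exec of the
// test binary (the standard "helper process" pattern): TestHelperProcess simply
// echoes the given stdout/stderr content and exits.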
func mockCmdBuilder(t *testing.T, stdoutContent, stderrContent string) CmdBuilder {
t.Helper()
return func(ctx context.Context) *exec.Cmd {
cmd := exec.CommandContext(ctx, os.Args[0],
"-test.run=TestHelperProcess",
"--",
stdoutContent,
stderrContent,
)
cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS=1")
return cmd
}
}
func TestHelperProcess(t *testing.T) {
if os.Getenv("GO_TEST_HELPER_PROCESS") != "1" {
return
}
args := os.Args
for i, arg := range args {
if arg == "--" {
args = args[i+1:]
break
}
}
if len(args) >= 1 {
_, _ = fmt.Fprint(os.Stdout, args[0])
}
if len(args) >= 2 {
_, _ = fmt.Fprint(os.Stderr, args[1])
}
os.Exit(0)
}
func validStderr() string {
return `pg_basebackup: initiating base backup, waiting for checkpoint to complete
pg_basebackup: checkpoint completed
pg_basebackup: write-ahead log start point: 0/2000028 on timeline 1
pg_basebackup: starting background WAL receiver
pg_basebackup: write-ahead log end point: 0/2000100
pg_basebackup: waiting for background process to finish streaming ...
pg_basebackup: syncing data to disk ...
pg_basebackup: base backup completed`
}
func writeJSON(w http.ResponseWriter, v any) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(v); err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
}
func waitForCondition(t *testing.T, condition func() bool, timeout time.Duration) {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if condition() {
return
}
time.Sleep(50 * time.Millisecond)
}
t.Fatalf("condition not met within %v", timeout)
}
func setRetryDelay(d time.Duration) {
retryDelayOverride = &d
}
func init() {
retryDelayOverride = nil
}

View File

@@ -0,0 +1,75 @@
package full_backup
import (
"fmt"
"regexp"
"strconv"
"strings"
)
const defaultWalSegmentSize uint32 = 16 * 1024 * 1024 // 16 MB
var (
startLSNRegex = regexp.MustCompile(`write-ahead log start point: ([0-9A-Fa-f]+/[0-9A-Fa-f]+)`)
stopLSNRegex = regexp.MustCompile(`write-ahead log end point: ([0-9A-Fa-f]+/[0-9A-Fa-f]+)`)
)
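// ParseBasebackupStderr extracts the WAL start and stop LSNs from
// pg_basebackup's --verbose stderr and converts them to segment names.
// Note that the timeline reported in stderr is not parsed; timeline 1 is assumed.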
func ParseBasebackupStderr(stderr string) (startSegment, stopSegment string, err error) {
startMatch := startLSNRegex.FindStringSubmatch(stderr)
if len(startMatch) < 2 {
return "", "", fmt.Errorf("failed to parse start WAL location from pg_basebackup stderr")
}
stopMatch := stopLSNRegex.FindStringSubmatch(stderr)
if len(stopMatch) < 2 {
return "", "", fmt.Errorf("failed to parse stop WAL location from pg_basebackup stderr")
}
startSegment, err = LSNToSegmentName(startMatch[1], 1, defaultWalSegmentSize)
if err != nil {
return "", "", fmt.Errorf("failed to convert start LSN to segment name: %w", err)
}
stopSegment, err = LSNToSegmentName(stopMatch[1], 1, defaultWalSegmentSize)
if err != nil {
return "", "", fmt.Errorf("failed to convert stop LSN to segment name: %w", err)
}
return startSegment, stopSegment, nil
}
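// LSNToSegmentName maps an LSN in "high/low" hex form to the 24-character WAL
// file name TTTTTTTTXXXXXXXXSSSSSSSS (timeline, XLogId, segment number within
// that XLogId, each as 8 hex digits).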
func LSNToSegmentName(lsn string, timelineID, walSegmentSize uint32) (string, error) {
high, low, err := parseLSN(lsn)
if err != nil {
return "", err
}
segmentsPerXLogID := uint32(0x100000000 / uint64(walSegmentSize))
logID := high
segmentOffset := low / walSegmentSize
if segmentOffset >= segmentsPerXLogID {
return "", fmt.Errorf("segment offset %d exceeds segments per XLogId %d", segmentOffset, segmentsPerXLogID)
}
return fmt.Sprintf("%08X%08X%08X", timelineID, logID, segmentOffset), nil
}
func parseLSN(lsn string) (high, low uint32, err error) {
parts := strings.SplitN(lsn, "/", 2)
if len(parts) != 2 {
return 0, 0, fmt.Errorf("invalid LSN format: %q (expected X/Y)", lsn)
}
highVal, err := strconv.ParseUint(parts[0], 16, 32)
if err != nil {
return 0, 0, fmt.Errorf("invalid LSN high part %q: %w", parts[0], err)
}
lowVal, err := strconv.ParseUint(parts[1], 16, 32)
if err != nil {
return 0, 0, fmt.Errorf("invalid LSN low part %q: %w", parts[1], err)
}
return uint32(highVal), uint32(lowVal), nil
}
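// Worked examples with the 16 MB default segment size (0x1000000 bytes):
//
//   "0/2000028"  -> high=0x0, low=0x2000028,  offset=0x2000028/0x1000000 = 0x02 -> "000000010000000000000002"
//   "1/AB000028" -> high=0x1, low=0xAB000028, offset=0xAB                       -> "0000000100000001000000AB"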

View File

@@ -0,0 +1,157 @@
package full_backup
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_ParseBasebackupStderr_WithPG17FetchOutput_ExtractsCorrectSegments(t *testing.T) {
stderr := `pg_basebackup: initiating base backup, waiting for checkpoint to complete
pg_basebackup: checkpoint completed
pg_basebackup: write-ahead log start point: 0/2000028 on timeline 1
pg_basebackup: starting background WAL receiver
pg_basebackup: write-ahead log end point: 0/2000100
pg_basebackup: waiting for background process to finish streaming ...
pg_basebackup: syncing data to disk ...
pg_basebackup: renaming backup_manifest.tmp to backup_manifest
pg_basebackup: base backup completed`
startSeg, stopSeg, err := ParseBasebackupStderr(stderr)
require.NoError(t, err)
assert.Equal(t, "000000010000000000000002", startSeg)
assert.Equal(t, "000000010000000000000002", stopSeg)
}
func Test_ParseBasebackupStderr_WithHighLSNValues_ExtractsCorrectSegments(t *testing.T) {
stderr := `pg_basebackup: write-ahead log start point: 1/AB000028 on timeline 1
pg_basebackup: write-ahead log end point: 1/AC000000`
startSeg, stopSeg, err := ParseBasebackupStderr(stderr)
require.NoError(t, err)
assert.Equal(t, "0000000100000001000000AB", startSeg)
assert.Equal(t, "0000000100000001000000AC", stopSeg)
}
func Test_ParseBasebackupStderr_WithHighLogID_ExtractsCorrectSegments(t *testing.T) {
stderr := `pg_basebackup: write-ahead log start point: A/FF000028 on timeline 1
pg_basebackup: write-ahead log end point: B/1000000`
startSeg, stopSeg, err := ParseBasebackupStderr(stderr)
require.NoError(t, err)
assert.Equal(t, "000000010000000A000000FF", startSeg)
assert.Equal(t, "000000010000000B00000001", stopSeg)
}
func Test_ParseBasebackupStderr_WhenStartLSNMissing_ReturnsError(t *testing.T) {
stderr := `pg_basebackup: write-ahead log end point: 0/2000100
pg_basebackup: base backup completed`
_, _, err := ParseBasebackupStderr(stderr)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse start WAL location")
}
func Test_ParseBasebackupStderr_WhenStopLSNMissing_ReturnsError(t *testing.T) {
stderr := `pg_basebackup: write-ahead log start point: 0/2000028 on timeline 1
pg_basebackup: base backup completed`
_, _, err := ParseBasebackupStderr(stderr)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse stop WAL location")
}
func Test_ParseBasebackupStderr_WhenEmptyStderr_ReturnsError(t *testing.T) {
_, _, err := ParseBasebackupStderr("")
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse start WAL location")
}
func Test_LSNToSegmentName_WithBoundaryValues_ConvertsCorrectly(t *testing.T) {
tests := []struct {
name string
lsn string
timeline uint32
segSize uint32
expected string
}{
{
name: "first segment",
lsn: "0/1000000",
timeline: 1,
segSize: 16 * 1024 * 1024,
expected: "000000010000000000000001",
},
{
name: "segment at boundary FF",
lsn: "0/FF000000",
timeline: 1,
segSize: 16 * 1024 * 1024,
expected: "0000000100000000000000FF",
},
{
name: "segment in second log file",
lsn: "1/0",
timeline: 1,
segSize: 16 * 1024 * 1024,
expected: "000000010000000100000000",
},
{
name: "segment with offset within 16MB",
lsn: "0/200ABCD",
timeline: 1,
segSize: 16 * 1024 * 1024,
expected: "000000010000000000000002",
},
{
name: "zero LSN",
lsn: "0/0",
timeline: 1,
segSize: 16 * 1024 * 1024,
expected: "000000010000000000000000",
},
{
name: "high timeline ID",
lsn: "0/1000000",
timeline: 2,
segSize: 16 * 1024 * 1024,
expected: "000000020000000000000001",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := LSNToSegmentName(tt.lsn, tt.timeline, tt.segSize)
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
func Test_LSNToSegmentName_WithInvalidLSN_ReturnsError(t *testing.T) {
tests := []struct {
name string
lsn string
}{
{name: "no slash", lsn: "012345"},
{name: "empty string", lsn: ""},
{name: "invalid hex high", lsn: "GG/0"},
{name: "invalid hex low", lsn: "0/ZZ"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := LSNToSegmentName(tt.lsn, 1, 16*1024*1024)
require.Error(t, err)
})
}
}

View File

@@ -0,0 +1,444 @@
package restore
import (
"archive/tar"
"context"
"errors"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strings"
"time"
"github.com/klauspost/compress/zstd"
"databasus-agent/internal/features/api"
)
const (
walRestoreDir = "databasus-wal-restore"
maxRetryAttempts = 3
retryBaseDelay = 1 * time.Second
recoverySignalFile = "recovery.signal"
autoConfFile = "postgresql.auto.conf"
dockerContainerPgDataDir = "/var/lib/postgresql/data"
)
var retryDelayOverride *time.Duration
type Restorer struct {
apiClient *api.Client
log *slog.Logger
targetPgDataDir string
backupID string
targetTime string
pgType string
}
func NewRestorer(
apiClient *api.Client,
log *slog.Logger,
targetPgDataDir string,
backupID string,
targetTime string,
pgType string,
) *Restorer {
return &Restorer{
apiClient,
log,
targetPgDataDir,
backupID,
targetTime,
pgType,
}
}
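// Run executes the restore end to end: it parses the optional RFC3339 target
// time, validates that the target PGDATA directory exists and is empty, fetches
// the restore plan, downloads and extracts the basebackup, downloads all WAL
// segments into databasus-wal-restore/, writes recovery.signal plus the
// recovery settings, and finally sets PGDATA permissions to 0700.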
func (r *Restorer) Run(ctx context.Context) error {
var parsedTargetTime *time.Time
if r.targetTime != "" {
parsed, err := time.Parse(time.RFC3339, r.targetTime)
if err != nil {
return fmt.Errorf("invalid --target-time format (expected RFC3339, e.g. 2026-02-28T14:30:00Z): %w", err)
}
parsedTargetTime = &parsed
}
if err := r.validateTargetPgDataDir(); err != nil {
return err
}
plan, err := r.getRestorePlanFromServer(ctx)
if err != nil {
return err
}
r.logRestorePlan(plan, parsedTargetTime)
r.log.Info("Downloading and extracting basebackup...")
if err := r.downloadAndExtractBasebackup(ctx, plan.FullBackup.BackupID); err != nil {
return fmt.Errorf("basebackup download failed: %w", err)
}
r.log.Info("Basebackup extracted successfully")
if err := r.downloadAllWalSegments(ctx, plan.WalSegments); err != nil {
return err
}
if err := r.configurePostgresRecovery(parsedTargetTime); err != nil {
return fmt.Errorf("failed to configure recovery: %w", err)
}
if err := os.Chmod(r.targetPgDataDir, 0o700); err != nil {
return fmt.Errorf("set PGDATA permissions: %w", err)
}
r.printCompletionMessage()
return nil
}
func (r *Restorer) validateTargetPgDataDir() error {
info, err := os.Stat(r.targetPgDataDir)
if err != nil {
if os.IsNotExist(err) {
return fmt.Errorf("target pgdata directory does not exist: %s", r.targetPgDataDir)
}
return fmt.Errorf("cannot access target pgdata directory: %w", err)
}
if !info.IsDir() {
return fmt.Errorf("target pgdata path is not a directory: %s", r.targetPgDataDir)
}
entries, err := os.ReadDir(r.targetPgDataDir)
if err != nil {
return fmt.Errorf("cannot read target pgdata directory: %w", err)
}
if len(entries) > 0 {
return fmt.Errorf("target pgdata directory is not empty: %s", r.targetPgDataDir)
}
return nil
}
func (r *Restorer) getRestorePlanFromServer(ctx context.Context) (*api.GetRestorePlanResponse, error) {
plan, planErr, err := r.apiClient.GetRestorePlan(ctx, r.backupID)
if err != nil {
return nil, fmt.Errorf("failed to fetch restore plan: %w", err)
}
if planErr != nil {
if planErr.LastContiguousSegment != "" {
return nil, fmt.Errorf("restore plan error: %s (last contiguous segment: %s)",
planErr.Message, planErr.LastContiguousSegment)
}
return nil, fmt.Errorf("restore plan error: %s", planErr.Message)
}
return plan, nil
}
func (r *Restorer) logRestorePlan(plan *api.GetRestorePlanResponse, parsedTargetTime *time.Time) {
recoveryTarget := "full recovery (all available WAL)"
if parsedTargetTime != nil {
recoveryTarget = parsedTargetTime.Format(time.RFC3339)
}
r.log.Info("Restore plan",
"fullBackupID", plan.FullBackup.BackupID,
"fullBackupCreatedAt", plan.FullBackup.CreatedAt.Format(time.RFC3339),
"pgVersion", plan.FullBackup.PgVersion,
"walSegmentCount", len(plan.WalSegments),
"totalDownloadSize", formatSizeBytes(plan.TotalSizeBytes),
"latestAvailableSegment", plan.LatestAvailableSegment,
"recoveryTarget", recoveryTarget,
)
}
func (r *Restorer) downloadAndExtractBasebackup(ctx context.Context, backupID string) error {
body, err := r.apiClient.DownloadBackupFile(ctx, backupID)
if err != nil {
return err
}
defer func() { _ = body.Close() }()
zstdReader, err := zstd.NewReader(body)
if err != nil {
return fmt.Errorf("create zstd decompressor: %w", err)
}
defer zstdReader.Close()
tarReader := tar.NewReader(zstdReader)
return r.extractTarArchive(tarReader)
}
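// extractTarArchive writes every tar entry under targetPgDataDir, rejecting
// entries that would escape it via path traversal and skipping entry types
// other than directories, regular files, symlinks and hard links.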
func (r *Restorer) extractTarArchive(tarReader *tar.Reader) error {
for {
header, err := tarReader.Next()
if errors.Is(err, io.EOF) {
return nil
}
if err != nil {
return fmt.Errorf("read tar entry: %w", err)
}
targetPath := filepath.Join(r.targetPgDataDir, header.Name)
relativePath, err := filepath.Rel(r.targetPgDataDir, targetPath)
if err != nil || strings.HasPrefix(relativePath, "..") {
return fmt.Errorf("tar entry attempts path traversal: %s", header.Name)
}
switch header.Typeflag {
case tar.TypeDir:
if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil {
return fmt.Errorf("create directory %s: %w", header.Name, err)
}
case tar.TypeReg:
parentDir := filepath.Dir(targetPath)
if err := os.MkdirAll(parentDir, 0o755); err != nil {
return fmt.Errorf("create parent directory for %s: %w", header.Name, err)
}
file, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
if err != nil {
return fmt.Errorf("create file %s: %w", header.Name, err)
}
if _, err := io.Copy(file, tarReader); err != nil {
_ = file.Close()
return fmt.Errorf("write file %s: %w", header.Name, err)
}
_ = file.Close()
case tar.TypeSymlink:
if err := os.Symlink(header.Linkname, targetPath); err != nil {
return fmt.Errorf("create symlink %s: %w", header.Name, err)
}
case tar.TypeLink:
linkTarget := filepath.Join(r.targetPgDataDir, header.Linkname)
if err := os.Link(linkTarget, targetPath); err != nil {
return fmt.Errorf("create hard link %s: %w", header.Name, err)
}
default:
r.log.Warn("Skipping unsupported tar entry type",
"name", header.Name,
"type", header.Typeflag,
)
}
}
}
func (r *Restorer) downloadAllWalSegments(ctx context.Context, segments []api.RestorePlanWalSegment) error {
walRestorePath := filepath.Join(r.targetPgDataDir, walRestoreDir)
if err := os.MkdirAll(walRestorePath, 0o755); err != nil {
return fmt.Errorf("create WAL restore directory: %w", err)
}
for segmentIndex, segment := range segments {
if err := r.downloadWalSegmentWithRetry(ctx, segment, segmentIndex, len(segments)); err != nil {
return err
}
}
return nil
}
func (r *Restorer) downloadWalSegmentWithRetry(
ctx context.Context,
segment api.RestorePlanWalSegment,
segmentIndex int,
segmentsTotal int,
) error {
r.log.Info("Downloading WAL segment",
"segment", segment.SegmentName,
"progress", fmt.Sprintf("%d/%d", segmentIndex+1, segmentsTotal),
)
var lastErr error
for attempt := range maxRetryAttempts {
if err := r.downloadWalSegment(ctx, segment); err != nil {
lastErr = err
delay := r.getRetryDelay(attempt)
r.log.Warn("WAL segment download failed, retrying",
"segment", segment.SegmentName,
"attempt", attempt+1,
"maxAttempts", maxRetryAttempts,
"retryDelay", delay,
"error", err,
)
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(delay):
continue
}
}
return nil
}
return fmt.Errorf("failed to download WAL segment %s after %d attempts: %w",
segment.SegmentName, maxRetryAttempts, lastErr)
}
func (r *Restorer) downloadWalSegment(ctx context.Context, segment api.RestorePlanWalSegment) error {
body, err := r.apiClient.DownloadBackupFile(ctx, segment.BackupID)
if err != nil {
return err
}
defer func() { _ = body.Close() }()
zstdReader, err := zstd.NewReader(body)
if err != nil {
return fmt.Errorf("create zstd decompressor: %w", err)
}
defer zstdReader.Close()
segmentPath := filepath.Join(r.targetPgDataDir, walRestoreDir, segment.SegmentName)
file, err := os.Create(segmentPath)
if err != nil {
return fmt.Errorf("create WAL segment file: %w", err)
}
defer func() { _ = file.Close() }()
if _, err := io.Copy(file, zstdReader); err != nil {
return fmt.Errorf("write WAL segment: %w", err)
}
return nil
}
func (r *Restorer) configurePostgresRecovery(parsedTargetTime *time.Time) error {
recoverySignalPath := filepath.Join(r.targetPgDataDir, recoverySignalFile)
if err := os.WriteFile(recoverySignalPath, []byte{}, 0o644); err != nil {
return fmt.Errorf("create recovery.signal: %w", err)
}
walRestoreAbsPath, err := r.resolveWalRestorePath()
if err != nil {
return err
}
autoConfPath := filepath.Join(r.targetPgDataDir, autoConfFile)
confFile, err := os.OpenFile(autoConfPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return fmt.Errorf("open postgresql.auto.conf: %w", err)
}
defer func() { _ = confFile.Close() }()
var configLines strings.Builder
configLines.WriteString("\n# Added by databasus-agent restore\n")
fmt.Fprintf(&configLines, "restore_command = 'cp %s/%%f %%p'\n", walRestoreAbsPath)
fmt.Fprintf(&configLines, "recovery_end_command = 'rm -rf %s'\n", walRestoreAbsPath)
configLines.WriteString("recovery_target_action = 'promote'\n")
if parsedTargetTime != nil {
fmt.Fprintf(&configLines, "recovery_target_time = '%s'\n", parsedTargetTime.Format(time.RFC3339))
}
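// With a hypothetical host PGDATA of /var/lib/pgdata and a recovery target time
// set, the section appended to postgresql.auto.conf would look roughly like:
//
//   # Added by databasus-agent restore
//   restore_command = 'cp /var/lib/pgdata/databasus-wal-restore/%f %p'
//   recovery_end_command = 'rm -rf /var/lib/pgdata/databasus-wal-restore'
//   recovery_target_action = 'promote'
//   recovery_target_time = '2026-02-28T14:30:00Z'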
if _, err := confFile.WriteString(configLines.String()); err != nil {
return fmt.Errorf("write to postgresql.auto.conf: %w", err)
}
return nil
}
func (r *Restorer) printCompletionMessage() {
absPgDataDir, _ := filepath.Abs(r.targetPgDataDir)
isDocker := r.pgType == "docker"
fmt.Printf("\nRestore complete. PGDATA directory is ready at %s.\n", absPgDataDir)
fmt.Print(`
What happens when you start PostgreSQL:
1. PostgreSQL detects recovery.signal and enters recovery mode
2. It replays WAL from the basebackup's consistency point
3. It executes restore_command to fetch WAL segments from databasus-wal-restore/
4. WAL replay continues until target_time (if PITR) or end of available WAL
5. recovery_end_command automatically removes databasus-wal-restore/
6. PostgreSQL promotes to primary and removes recovery.signal
7. Normal operations resume
`)
if isDocker {
fmt.Printf(`
Start PostgreSQL by launching a container with the restored data mounted:
docker run -d -v %s:%s postgres:<VERSION>
Or if you have an existing container:
docker start <CONTAINER_NAME>
Ensure %s is mounted as the container's pgdata volume at %s.
`, absPgDataDir, dockerContainerPgDataDir, absPgDataDir, dockerContainerPgDataDir)
} else {
fmt.Printf(`
Start PostgreSQL:
pg_ctl -D %s start
Note: If you move the PGDATA directory before starting PostgreSQL,
update restore_command and recovery_end_command paths in
postgresql.auto.conf accordingly.
`, absPgDataDir)
}
}
func (r *Restorer) resolveWalRestorePath() (string, error) {
if r.pgType == "docker" {
return dockerContainerPgDataDir + "/" + walRestoreDir, nil
}
absPgDataDir, err := filepath.Abs(r.targetPgDataDir)
if err != nil {
return "", fmt.Errorf("resolve absolute path: %w", err)
}
absPgDataDir = filepath.ToSlash(absPgDataDir)
return absPgDataDir + "/" + walRestoreDir, nil
}
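// getRetryDelay returns an exponential backoff delay: with retryBaseDelay = 1s
// and maxRetryAttempts = 3 the per-attempt delays are 1s, 2s and 4s (1 << attempt),
// unless retryDelayOverride is set (tests use it to shrink the delays).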
func (r *Restorer) getRetryDelay(attempt int) time.Duration {
if retryDelayOverride != nil {
return *retryDelayOverride
}
return retryBaseDelay * time.Duration(1<<attempt)
}
func formatSizeBytes(sizeBytes int64) string {
const (
kilobyte = 1024
megabyte = 1024 * kilobyte
gigabyte = 1024 * megabyte
)
switch {
case sizeBytes >= gigabyte:
return fmt.Sprintf("%.2f GB", float64(sizeBytes)/float64(gigabyte))
case sizeBytes >= megabyte:
return fmt.Sprintf("%.2f MB", float64(sizeBytes)/float64(megabyte))
case sizeBytes >= kilobyte:
return fmt.Sprintf("%.2f KB", float64(sizeBytes)/float64(kilobyte))
default:
return fmt.Sprintf("%d B", sizeBytes)
}
}

View File

@@ -0,0 +1,711 @@
package restore
import (
"archive/tar"
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/klauspost/compress/zstd"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"databasus-agent/internal/features/api"
"databasus-agent/internal/logger"
)
const (
testRestorePlanPath = "/api/v1/backups/postgres/wal/restore/plan"
testRestoreDownloadPath = "/api/v1/backups/postgres/wal/restore/download"
testFullBackupID = "full-backup-id-1234"
testWalSegment1 = "000000010000000100000001"
testWalSegment2 = "000000010000000100000002"
)
func Test_RunRestore_WhenBasebackupAndWalSegmentsAvailable_FilesExtractedAndRecoveryConfigured(t *testing.T) {
tarFiles := map[string][]byte{
"PG_VERSION": []byte("16"),
"base/1/somefile": []byte("table-data"),
}
zstdTarData := createZstdTar(t, tarFiles)
walData1 := createZstdData(t, []byte("wal-segment-1-data"))
walData2 := createZstdData(t, []byte("wal-segment-2-data"))
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testRestorePlanPath:
writeJSON(w, api.GetRestorePlanResponse{
FullBackup: api.RestorePlanFullBackup{
BackupID: testFullBackupID,
FullBackupWalStartSegment: testWalSegment1,
FullBackupWalStopSegment: testWalSegment1,
PgVersion: "16",
CreatedAt: time.Now().UTC(),
SizeBytes: 1024,
},
WalSegments: []api.RestorePlanWalSegment{
{BackupID: "wal-1", SegmentName: testWalSegment1, SizeBytes: 512},
{BackupID: "wal-2", SegmentName: testWalSegment2, SizeBytes: 512},
},
TotalSizeBytes: 2048,
LatestAvailableSegment: testWalSegment2,
})
case testRestoreDownloadPath:
backupID := r.URL.Query().Get("backupId")
switch backupID {
case testFullBackupID:
w.Header().Set("Content-Type", "application/octet-stream")
_, _ = w.Write(zstdTarData)
case "wal-1":
w.Header().Set("Content-Type", "application/octet-stream")
_, _ = w.Write(walData1)
case "wal-2":
w.Header().Set("Content-Type", "application/octet-stream")
_, _ = w.Write(walData2)
default:
w.WriteHeader(http.StatusBadRequest)
}
default:
w.WriteHeader(http.StatusNotFound)
}
})
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
err := restorer.Run(t.Context())
require.NoError(t, err)
pgVersionContent, err := os.ReadFile(filepath.Join(targetDir, "PG_VERSION"))
require.NoError(t, err)
assert.Equal(t, "16", string(pgVersionContent))
someFileContent, err := os.ReadFile(filepath.Join(targetDir, "base", "1", "somefile"))
require.NoError(t, err)
assert.Equal(t, "table-data", string(someFileContent))
walSegment1Content, err := os.ReadFile(filepath.Join(targetDir, walRestoreDir, testWalSegment1))
require.NoError(t, err)
assert.Equal(t, "wal-segment-1-data", string(walSegment1Content))
walSegment2Content, err := os.ReadFile(filepath.Join(targetDir, walRestoreDir, testWalSegment2))
require.NoError(t, err)
assert.Equal(t, "wal-segment-2-data", string(walSegment2Content))
recoverySignalPath := filepath.Join(targetDir, "recovery.signal")
recoverySignalInfo, err := os.Stat(recoverySignalPath)
require.NoError(t, err)
assert.Equal(t, int64(0), recoverySignalInfo.Size())
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
require.NoError(t, err)
autoConfStr := string(autoConfContent)
assert.Contains(t, autoConfStr, "restore_command")
assert.Contains(t, autoConfStr, walRestoreDir)
assert.Contains(t, autoConfStr, "recovery_target_action = 'promote'")
assert.Contains(t, autoConfStr, "recovery_end_command")
assert.NotContains(t, autoConfStr, "recovery_target_time")
}
func Test_RunRestore_WhenTargetTimeProvided_RecoveryTargetTimeWrittenToConfig(t *testing.T) {
tarFiles := map[string][]byte{"PG_VERSION": []byte("16")}
zstdTarData := createZstdTar(t, tarFiles)
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testRestorePlanPath:
writeJSON(w, api.GetRestorePlanResponse{
FullBackup: api.RestorePlanFullBackup{
BackupID: testFullBackupID,
PgVersion: "16",
CreatedAt: time.Now().UTC(),
SizeBytes: 1024,
},
WalSegments: []api.RestorePlanWalSegment{},
TotalSizeBytes: 1024,
LatestAvailableSegment: "",
})
case testRestoreDownloadPath:
w.Header().Set("Content-Type", "application/octet-stream")
_, _ = w.Write(zstdTarData)
default:
w.WriteHeader(http.StatusNotFound)
}
})
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "2026-02-28T14:30:00Z", "")
err := restorer.Run(t.Context())
require.NoError(t, err)
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
require.NoError(t, err)
assert.Contains(t, string(autoConfContent), "recovery_target_time = '2026-02-28T14:30:00Z'")
}
func Test_RunRestore_WhenPgDataDirNotEmpty_ReturnsError(t *testing.T) {
targetDir := createTestTargetDir(t)
err := os.WriteFile(filepath.Join(targetDir, "existing-file"), []byte("data"), 0o644)
require.NoError(t, err)
restorer := newTestRestorer("http://localhost:0", targetDir, "", "", "")
err = restorer.Run(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "not empty")
}
func Test_RunRestore_WhenPgDataDirDoesNotExist_ReturnsError(t *testing.T) {
nonExistentDir := filepath.Join(os.TempDir(), "databasus-test-nonexistent-dir-12345")
restorer := newTestRestorer("http://localhost:0", nonExistentDir, "", "", "")
err := restorer.Run(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "does not exist")
}
func Test_RunRestore_WhenNoBackupsAvailable_ReturnsError(t *testing.T) {
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
_ = json.NewEncoder(w).Encode(api.GetRestorePlanErrorResponse{
Error: "no_backups",
Message: "No full backups available",
})
})
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
err := restorer.Run(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "No full backups available")
}
func Test_RunRestore_WhenWalChainBroken_ReturnsError(t *testing.T) {
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
_ = json.NewEncoder(w).Encode(api.GetRestorePlanErrorResponse{
Error: "wal_chain_broken",
Message: "WAL chain broken",
LastContiguousSegment: testWalSegment1,
})
})
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
err := restorer.Run(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "WAL chain broken")
assert.Contains(t, err.Error(), testWalSegment1)
}
func Test_DownloadWalSegment_WhenFirstAttemptFails_RetriesAndSucceeds(t *testing.T) {
tarFiles := map[string][]byte{"PG_VERSION": []byte("16")}
zstdTarData := createZstdTar(t, tarFiles)
walData := createZstdData(t, []byte("wal-segment-data"))
var mu sync.Mutex
var walDownloadAttempts int
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testRestorePlanPath:
writeJSON(w, api.GetRestorePlanResponse{
FullBackup: api.RestorePlanFullBackup{
BackupID: testFullBackupID,
PgVersion: "16",
CreatedAt: time.Now().UTC(),
SizeBytes: 1024,
},
WalSegments: []api.RestorePlanWalSegment{
{BackupID: "wal-1", SegmentName: testWalSegment1, SizeBytes: 512},
},
TotalSizeBytes: 1536,
LatestAvailableSegment: testWalSegment1,
})
case testRestoreDownloadPath:
backupID := r.URL.Query().Get("backupId")
if backupID == testFullBackupID {
w.Header().Set("Content-Type", "application/octet-stream")
_, _ = w.Write(zstdTarData)
return
}
mu.Lock()
walDownloadAttempts++
attempt := walDownloadAttempts
mu.Unlock()
if attempt == 1 {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(`{"error":"storage unavailable"}`))
return
}
w.Header().Set("Content-Type", "application/octet-stream")
_, _ = w.Write(walData)
default:
w.WriteHeader(http.StatusNotFound)
}
})
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
origDelay := retryDelayOverride
testDelay := 10 * time.Millisecond
retryDelayOverride = &testDelay
defer func() { retryDelayOverride = origDelay }()
err := restorer.Run(t.Context())
require.NoError(t, err)
mu.Lock()
attempts := walDownloadAttempts
mu.Unlock()
assert.Equal(t, 2, attempts)
walContent, err := os.ReadFile(filepath.Join(targetDir, walRestoreDir, testWalSegment1))
require.NoError(t, err)
assert.Equal(t, "wal-segment-data", string(walContent))
}
func Test_DownloadWalSegment_WhenAllAttemptsFail_ReturnsErrorWithSegmentName(t *testing.T) {
tarFiles := map[string][]byte{"PG_VERSION": []byte("16")}
zstdTarData := createZstdTar(t, tarFiles)
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testRestorePlanPath:
writeJSON(w, api.GetRestorePlanResponse{
FullBackup: api.RestorePlanFullBackup{
BackupID: testFullBackupID,
PgVersion: "16",
CreatedAt: time.Now().UTC(),
SizeBytes: 1024,
},
WalSegments: []api.RestorePlanWalSegment{
{BackupID: "wal-1", SegmentName: testWalSegment1, SizeBytes: 512},
},
TotalSizeBytes: 1536,
LatestAvailableSegment: testWalSegment1,
})
case testRestoreDownloadPath:
backupID := r.URL.Query().Get("backupId")
if backupID == testFullBackupID {
w.Header().Set("Content-Type", "application/octet-stream")
_, _ = w.Write(zstdTarData)
return
}
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(`{"error":"storage unavailable"}`))
default:
w.WriteHeader(http.StatusNotFound)
}
})
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
origDelay := retryDelayOverride
testDelay := 10 * time.Millisecond
retryDelayOverride = &testDelay
defer func() { retryDelayOverride = origDelay }()
err := restorer.Run(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), testWalSegment1)
assert.Contains(t, err.Error(), "3 attempts")
}
func Test_RunRestore_WhenInvalidTargetTimeFormat_ReturnsError(t *testing.T) {
targetDir := createTestTargetDir(t)
restorer := newTestRestorer("http://localhost:0", targetDir, "", "not-a-valid-time", "")
err := restorer.Run(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid --target-time format")
}
func Test_RunRestore_WhenBasebackupDownloadFails_ReturnsError(t *testing.T) {
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testRestorePlanPath:
writeJSON(w, api.GetRestorePlanResponse{
FullBackup: api.RestorePlanFullBackup{
BackupID: testFullBackupID,
PgVersion: "16",
CreatedAt: time.Now().UTC(),
SizeBytes: 1024,
},
WalSegments: []api.RestorePlanWalSegment{},
TotalSizeBytes: 1024,
LatestAvailableSegment: "",
})
case testRestoreDownloadPath:
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(`{"error":"storage error"}`))
default:
w.WriteHeader(http.StatusNotFound)
}
})
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
err := restorer.Run(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "basebackup download failed")
}
func Test_RunRestore_WhenNoWalSegmentsInPlan_BasebackupRestoredSuccessfully(t *testing.T) {
tarFiles := map[string][]byte{
"PG_VERSION": []byte("16"),
"global/pg_control": []byte("control-data"),
}
zstdTarData := createZstdTar(t, tarFiles)
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testRestorePlanPath:
writeJSON(w, api.GetRestorePlanResponse{
FullBackup: api.RestorePlanFullBackup{
BackupID: testFullBackupID,
PgVersion: "16",
CreatedAt: time.Now().UTC(),
SizeBytes: 1024,
},
WalSegments: []api.RestorePlanWalSegment{},
TotalSizeBytes: 1024,
LatestAvailableSegment: "",
})
case testRestoreDownloadPath:
w.Header().Set("Content-Type", "application/octet-stream")
_, _ = w.Write(zstdTarData)
default:
w.WriteHeader(http.StatusNotFound)
}
})
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
err := restorer.Run(t.Context())
require.NoError(t, err)
pgVersionContent, err := os.ReadFile(filepath.Join(targetDir, "PG_VERSION"))
require.NoError(t, err)
assert.Equal(t, "16", string(pgVersionContent))
walRestoreDirInfo, err := os.Stat(filepath.Join(targetDir, walRestoreDir))
require.NoError(t, err)
assert.True(t, walRestoreDirInfo.IsDir())
_, err = os.Stat(filepath.Join(targetDir, "recovery.signal"))
require.NoError(t, err)
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
require.NoError(t, err)
assert.Contains(t, string(autoConfContent), "restore_command")
}
func Test_RunRestore_WhenMakingApiCalls_AuthTokenIncludedInRequests(t *testing.T) {
tarFiles := map[string][]byte{"PG_VERSION": []byte("16")}
zstdTarData := createZstdTar(t, tarFiles)
var receivedAuthHeaders atomic.Int32
var mu sync.Mutex
var authHeaderValues []string
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader != "" {
receivedAuthHeaders.Add(1)
mu.Lock()
authHeaderValues = append(authHeaderValues, authHeader)
mu.Unlock()
}
switch r.URL.Path {
case testRestorePlanPath:
writeJSON(w, api.GetRestorePlanResponse{
FullBackup: api.RestorePlanFullBackup{
BackupID: testFullBackupID,
PgVersion: "16",
CreatedAt: time.Now().UTC(),
SizeBytes: 1024,
},
WalSegments: []api.RestorePlanWalSegment{},
TotalSizeBytes: 1024,
LatestAvailableSegment: "",
})
case testRestoreDownloadPath:
w.Header().Set("Content-Type", "application/octet-stream")
_, _ = w.Write(zstdTarData)
default:
w.WriteHeader(http.StatusNotFound)
}
})
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
err := restorer.Run(t.Context())
require.NoError(t, err)
assert.GreaterOrEqual(t, int(receivedAuthHeaders.Load()), 2)
mu.Lock()
defer mu.Unlock()
for _, headerValue := range authHeaderValues {
assert.Equal(t, "test-token", headerValue)
}
}
func Test_ConfigurePostgresRecovery_WhenPgTypeHost_UsesHostAbsolutePath(t *testing.T) {
tarFiles := map[string][]byte{"PG_VERSION": []byte("16")}
zstdTarData := createZstdTar(t, tarFiles)
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testRestorePlanPath:
writeJSON(w, api.GetRestorePlanResponse{
FullBackup: api.RestorePlanFullBackup{
BackupID: testFullBackupID,
PgVersion: "16",
CreatedAt: time.Now().UTC(),
SizeBytes: 1024,
},
WalSegments: []api.RestorePlanWalSegment{},
TotalSizeBytes: 1024,
LatestAvailableSegment: "",
})
case testRestoreDownloadPath:
w.Header().Set("Content-Type", "application/octet-stream")
_, _ = w.Write(zstdTarData)
default:
w.WriteHeader(http.StatusNotFound)
}
})
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "", "host")
err := restorer.Run(t.Context())
require.NoError(t, err)
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
require.NoError(t, err)
autoConfStr := string(autoConfContent)
absTargetDir, _ := filepath.Abs(targetDir)
absTargetDir = filepath.ToSlash(absTargetDir)
expectedWalPath := absTargetDir + "/" + walRestoreDir
assert.Contains(t, autoConfStr, fmt.Sprintf("restore_command = 'cp %s/%%f %%p'", expectedWalPath))
assert.Contains(t, autoConfStr, fmt.Sprintf("recovery_end_command = 'rm -rf %s'", expectedWalPath))
assert.NotContains(t, autoConfStr, "/var/lib/postgresql/data")
}
func Test_ConfigurePostgresRecovery_WhenPgTypeDocker_UsesContainerPath(t *testing.T) {
tarFiles := map[string][]byte{"PG_VERSION": []byte("16")}
zstdTarData := createZstdTar(t, tarFiles)
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testRestorePlanPath:
writeJSON(w, api.GetRestorePlanResponse{
FullBackup: api.RestorePlanFullBackup{
BackupID: testFullBackupID,
PgVersion: "16",
CreatedAt: time.Now().UTC(),
SizeBytes: 1024,
},
WalSegments: []api.RestorePlanWalSegment{},
TotalSizeBytes: 1024,
LatestAvailableSegment: "",
})
case testRestoreDownloadPath:
w.Header().Set("Content-Type", "application/octet-stream")
_, _ = w.Write(zstdTarData)
default:
w.WriteHeader(http.StatusNotFound)
}
})
targetDir := createTestTargetDir(t)
restorer := newTestRestorer(server.URL, targetDir, "", "", "docker")
err := restorer.Run(t.Context())
require.NoError(t, err)
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
require.NoError(t, err)
autoConfStr := string(autoConfContent)
expectedWalPath := "/var/lib/postgresql/data/" + walRestoreDir
assert.Contains(t, autoConfStr, fmt.Sprintf("restore_command = 'cp %s/%%f %%p'", expectedWalPath))
assert.Contains(t, autoConfStr, fmt.Sprintf("recovery_end_command = 'rm -rf %s'", expectedWalPath))
absTargetDir, _ := filepath.Abs(targetDir)
absTargetDir = filepath.ToSlash(absTargetDir)
assert.NotContains(t, autoConfStr, absTargetDir)
}
func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
t.Helper()
server := httptest.NewServer(handler)
t.Cleanup(server.Close)
return server
}
func createTestTargetDir(t *testing.T) string {
t.Helper()
baseDir := filepath.Join(".", ".test-tmp")
if err := os.MkdirAll(baseDir, 0o755); err != nil {
t.Fatalf("failed to create base test dir: %v", err)
}
dir, err := os.MkdirTemp(baseDir, t.Name()+"-*")
if err != nil {
t.Fatalf("failed to create test target dir: %v", err)
}
t.Cleanup(func() {
_ = os.RemoveAll(dir)
})
return dir
}
func createZstdTar(t *testing.T, files map[string][]byte) []byte {
t.Helper()
var tarBuffer bytes.Buffer
tarWriter := tar.NewWriter(&tarBuffer)
createdDirs := make(map[string]bool)
for name, content := range files {
dir := filepath.Dir(name)
if dir != "." && !createdDirs[dir] {
parts := strings.Split(filepath.ToSlash(dir), "/")
for partIndex := range parts {
partialDir := strings.Join(parts[:partIndex+1], "/")
if !createdDirs[partialDir] {
err := tarWriter.WriteHeader(&tar.Header{
Name: partialDir + "/",
Typeflag: tar.TypeDir,
Mode: 0o755,
})
require.NoError(t, err)
createdDirs[partialDir] = true
}
}
}
err := tarWriter.WriteHeader(&tar.Header{
Name: name,
Size: int64(len(content)),
Mode: 0o644,
Typeflag: tar.TypeReg,
})
require.NoError(t, err)
_, err = tarWriter.Write(content)
require.NoError(t, err)
}
require.NoError(t, tarWriter.Close())
var zstdBuffer bytes.Buffer
encoder, err := zstd.NewWriter(&zstdBuffer,
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(5)),
zstd.WithEncoderCRC(true),
)
require.NoError(t, err)
_, err = encoder.Write(tarBuffer.Bytes())
require.NoError(t, err)
require.NoError(t, encoder.Close())
return zstdBuffer.Bytes()
}
func createZstdData(t *testing.T, data []byte) []byte {
t.Helper()
var buffer bytes.Buffer
encoder, err := zstd.NewWriter(&buffer,
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(5)),
zstd.WithEncoderCRC(true),
)
require.NoError(t, err)
_, err = encoder.Write(data)
require.NoError(t, err)
require.NoError(t, encoder.Close())
return buffer.Bytes()
}
func newTestRestorer(serverURL, targetPgDataDir, backupID, targetTime, pgType string) *Restorer {
apiClient := api.NewClient(serverURL, "test-token", logger.GetLogger())
return NewRestorer(apiClient, logger.GetLogger(), targetPgDataDir, backupID, targetTime, pgType)
}
func writeJSON(w http.ResponseWriter, value any) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(value); err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
}

View File

@@ -0,0 +1,121 @@
//go:build !windows
package start
import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"os/exec"
"syscall"
"time"
)
const (
logFileName = "databasus.log"
stopTimeout = 30 * time.Second
stopPollInterval = 500 * time.Millisecond
daemonStartupDelay = 500 * time.Millisecond
)
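// Stop reads the agent's PID from the lock file, removes the lock file if the
// process is already gone, and otherwise sends SIGTERM and waits up to
// stopTimeout for the process to exit.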
func Stop(log *slog.Logger) error {
pid, err := ReadLockFilePID()
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return errors.New("agent is not running (no lock file found)")
}
return fmt.Errorf("failed to read lock file: %w", err)
}
if !isProcessAlive(pid) {
_ = os.Remove(lockFileName)
return fmt.Errorf("agent is not running (stale lock file removed, PID %d)", pid)
}
log.Info("Sending SIGTERM to agent", "pid", pid)
if err := syscall.Kill(pid, syscall.SIGTERM); err != nil {
return fmt.Errorf("failed to send SIGTERM to PID %d: %w", pid, err)
}
deadline := time.Now().Add(stopTimeout)
for time.Now().Before(deadline) {
if !isProcessAlive(pid) {
log.Info("Agent stopped", "pid", pid)
return nil
}
time.Sleep(stopPollInterval)
}
return fmt.Errorf("agent (PID %d) did not stop within %s — process may be stuck", pid, stopTimeout)
}
func Status(log *slog.Logger) error {
pid, err := ReadLockFilePID()
if err != nil {
if errors.Is(err, os.ErrNotExist) {
fmt.Println("Agent is not running")
return nil
}
return fmt.Errorf("failed to read lock file: %w", err)
}
if isProcessAlive(pid) {
fmt.Printf("Agent is running (PID %d)\n", pid)
} else {
fmt.Println("Agent is not running (stale lock file)")
_ = os.Remove(lockFileName)
}
return nil
}
func spawnDaemon(log *slog.Logger) (int, error) {
execPath, err := os.Executable()
if err != nil {
return 0, fmt.Errorf("failed to resolve executable path: %w", err)
}
args := []string{"_run"}
logFile, err := os.OpenFile(logFileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return 0, fmt.Errorf("failed to open log file %s: %w", logFileName, err)
}
cwd, err := os.Getwd()
if err != nil {
_ = logFile.Close()
return 0, fmt.Errorf("failed to get working directory: %w", err)
}
cmd := exec.CommandContext(context.Background(), execPath, args...)
cmd.Dir = cwd
cmd.Stderr = logFile
cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}
if err := cmd.Start(); err != nil {
_ = logFile.Close()
return 0, fmt.Errorf("failed to start daemon process: %w", err)
}
pid := cmd.Process.Pid
// Detach — we don't wait for the child
_ = logFile.Close()
time.Sleep(daemonStartupDelay)
if !isProcessAlive(pid) {
return 0, fmt.Errorf("daemon process (PID %d) exited immediately — check %s for details", pid, logFileName)
}
log.Info("Daemon spawned", "pid", pid, "log", logFileName)
return pid, nil
}
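
How these subcommands reach the functions above is outside this excerpt; the only hint is the hidden `_run` argument that spawnDaemon passes to the re-invoked binary. A minimal dispatch sketch, assuming plain os.Args handling in the same package — the subcommand names other than `_run` and the dispatch helper itself are assumptions, not taken from the diff:

// Hypothetical wiring, same package as Start/Stop/Status; only "_run" appears in the diff above.
package start

import (
	"errors"
	"fmt"
	"log/slog"
	"os"

	"databasus-agent/internal/config"
)

func dispatch(cfg *config.Config, agentVersion string, isDev bool, log *slog.Logger) error {
	if len(os.Args) < 2 {
		return errors.New("usage: databasus-agent <start|stop|status>")
	}
	switch os.Args[1] {
	case "start": // verify config and PostgreSQL, then detach into the background
		return Start(cfg, agentVersion, isDev, log)
	case "_run": // internal: executed by the child process spawned in spawnDaemon
		return RunDaemon(cfg, agentVersion, isDev, log)
	case "stop":
		return Stop(log)
	case "status":
		return Status(log)
	default:
		return fmt.Errorf("unknown command: %q", os.Args[1])
	}
}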

View File

@@ -0,0 +1,20 @@
//go:build windows
package start
import (
"errors"
"log/slog"
)
func Stop(log *slog.Logger) error {
return errors.New("stop is not supported on Windows — use Ctrl+C in the terminal where the agent is running")
}
func Status(log *slog.Logger) error {
return errors.New("status is not supported on Windows — check the terminal where the agent is running")
}
func spawnDaemon(_ *slog.Logger) (int, error) {
return 0, errors.New("daemon mode is not supported on Windows")
}

View File

@@ -0,0 +1,132 @@
//go:build !windows
package start
import (
"errors"
"fmt"
"io"
"log/slog"
"os"
"strconv"
"strings"
"syscall"
)
const lockFileName = "databasus.lock"
func AcquireLock(log *slog.Logger) (*os.File, error) {
f, err := os.OpenFile(lockFileName, os.O_CREATE|os.O_RDWR, 0o644)
if err != nil {
return nil, fmt.Errorf("failed to open lock file: %w", err)
}
err = syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
if err == nil {
if err := writePID(f); err != nil {
_ = f.Close()
return nil, err
}
log.Info("Process lock acquired", "pid", os.Getpid(), "lockFile", lockFileName)
return f, nil
}
if !errors.Is(err, syscall.EWOULDBLOCK) {
_ = f.Close()
return nil, fmt.Errorf("failed to acquire lock: %w", err)
}
pid, pidErr := readLockPID(f)
_ = f.Close()
if pidErr != nil {
return nil, fmt.Errorf("another instance is already running")
}
return nil, fmt.Errorf("another instance is already running (PID %d)", pid)
}
func ReleaseLock(f *os.File) {
_ = syscall.Flock(int(f.Fd()), syscall.LOCK_UN)
lockedStat, lockedErr := f.Stat()
_ = f.Close()
if lockedErr != nil {
_ = os.Remove(lockFileName)
return
}
diskStat, diskErr := os.Stat(lockFileName)
if diskErr != nil {
return
}
if os.SameFile(lockedStat, diskStat) {
_ = os.Remove(lockFileName)
}
}
func ReadLockFilePID() (int, error) {
f, err := os.Open(lockFileName)
if err != nil {
return 0, err
}
defer func() { _ = f.Close() }()
return readLockPID(f)
}
func writePID(f *os.File) error {
if err := f.Truncate(0); err != nil {
return fmt.Errorf("failed to truncate lock file: %w", err)
}
if _, err := f.Seek(0, io.SeekStart); err != nil {
return fmt.Errorf("failed to seek lock file: %w", err)
}
if _, err := fmt.Fprintf(f, "%d\n", os.Getpid()); err != nil {
return fmt.Errorf("failed to write PID to lock file: %w", err)
}
return f.Sync()
}
func readLockPID(f *os.File) (int, error) {
if _, err := f.Seek(0, io.SeekStart); err != nil {
return 0, err
}
data, err := io.ReadAll(f)
if err != nil {
return 0, err
}
s := strings.TrimSpace(string(data))
if s == "" {
return 0, errors.New("lock file is empty")
}
pid, err := strconv.Atoi(s)
if err != nil {
return 0, fmt.Errorf("invalid PID in lock file: %w", err)
}
return pid, nil
}
func isProcessAlive(pid int) bool {
err := syscall.Kill(pid, 0)
if err == nil {
return true
}
if errors.Is(err, syscall.EPERM) {
return true
}
return false
}

View File

@@ -0,0 +1,148 @@
//go:build !windows
package start
import (
"fmt"
"os"
"strconv"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"databasus-agent/internal/logger"
)
func Test_AcquireLock_LockFileCreatedWithPID(t *testing.T) {
setupTempDir(t)
log := logger.GetLogger()
lockFile, err := AcquireLock(log)
require.NoError(t, err)
defer ReleaseLock(lockFile)
data, err := os.ReadFile(lockFileName)
require.NoError(t, err)
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
require.NoError(t, err)
assert.Equal(t, os.Getpid(), pid)
}
func Test_AcquireLock_SecondAcquireFails_WhenFirstHeld(t *testing.T) {
setupTempDir(t)
log := logger.GetLogger()
first, err := AcquireLock(log)
require.NoError(t, err)
defer ReleaseLock(first)
second, err := AcquireLock(log)
assert.Nil(t, second)
require.Error(t, err)
assert.Contains(t, err.Error(), "another instance is already running")
assert.Contains(t, err.Error(), fmt.Sprintf("PID %d", os.Getpid()))
}
func Test_AcquireLock_StaleLockReacquired_WhenProcessDead(t *testing.T) {
setupTempDir(t)
log := logger.GetLogger()
err := os.WriteFile(lockFileName, []byte("999999999\n"), 0o644)
require.NoError(t, err)
lockFile, err := AcquireLock(log)
require.NoError(t, err)
defer ReleaseLock(lockFile)
data, err := os.ReadFile(lockFileName)
require.NoError(t, err)
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
require.NoError(t, err)
assert.Equal(t, os.Getpid(), pid)
}
func Test_ReleaseLock_LockFileRemoved(t *testing.T) {
setupTempDir(t)
log := logger.GetLogger()
lockFile, err := AcquireLock(log)
require.NoError(t, err)
ReleaseLock(lockFile)
_, err = os.Stat(lockFileName)
assert.True(t, os.IsNotExist(err))
}
func Test_AcquireLock_ReacquiredAfterRelease(t *testing.T) {
setupTempDir(t)
log := logger.GetLogger()
first, err := AcquireLock(log)
require.NoError(t, err)
ReleaseLock(first)
second, err := AcquireLock(log)
require.NoError(t, err)
defer ReleaseLock(second)
data, err := os.ReadFile(lockFileName)
require.NoError(t, err)
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
require.NoError(t, err)
assert.Equal(t, os.Getpid(), pid)
}
func Test_isProcessAlive_ReturnsTrueForSelf(t *testing.T) {
assert.True(t, isProcessAlive(os.Getpid()))
}
func Test_isProcessAlive_ReturnsFalseForNonExistentPID(t *testing.T) {
assert.False(t, isProcessAlive(999999999))
}
func Test_readLockPID_ParsesValidPID(t *testing.T) {
setupTempDir(t)
f, err := os.CreateTemp("", "lock-test-*")
require.NoError(t, err)
defer os.Remove(f.Name())
_, err = f.WriteString("12345\n")
require.NoError(t, err)
pid, err := readLockPID(f)
require.NoError(t, err)
assert.Equal(t, 12345, pid)
}
func Test_readLockPID_ReturnsErrorForEmptyFile(t *testing.T) {
setupTempDir(t)
f, err := os.CreateTemp("", "lock-test-*")
require.NoError(t, err)
defer os.Remove(f.Name())
_, err = readLockPID(f)
require.Error(t, err)
assert.Contains(t, err.Error(), "lock file is empty")
}
func setupTempDir(t *testing.T) string {
t.Helper()
origDir, err := os.Getwd()
require.NoError(t, err)
dir := t.TempDir()
require.NoError(t, os.Chdir(dir))
t.Cleanup(func() { _ = os.Chdir(origDir) })
return dir
}

View File

@@ -0,0 +1,90 @@
//go:build !windows
package start
import (
"context"
"log/slog"
"os"
"syscall"
"time"
)
const lockWatchInterval = 5 * time.Second
type LockWatcher struct {
originalInode uint64
cancel context.CancelFunc
log *slog.Logger
}
func NewLockWatcher(lockFile *os.File, cancel context.CancelFunc, log *slog.Logger) (*LockWatcher, error) {
inode, err := getFileInode(lockFile)
if err != nil {
return nil, err
}
return &LockWatcher{
originalInode: inode,
cancel: cancel,
log: log,
}, nil
}
func (w *LockWatcher) Run(ctx context.Context) {
ticker := time.NewTicker(lockWatchInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
w.check()
}
}
}
func (w *LockWatcher) check() {
info, err := os.Stat(lockFileName)
if err != nil {
w.log.Error("Lock file disappeared, shutting down", "file", lockFileName, "error", err)
w.cancel()
return
}
currentInode, err := getStatInode(info)
if err != nil {
w.log.Error("Failed to read lock file inode, shutting down", "error", err)
w.cancel()
return
}
if currentInode != w.originalInode {
w.log.Error("Lock file was replaced (inode changed), shutting down",
"originalInode", w.originalInode,
"currentInode", currentInode,
)
w.cancel()
}
}
func getFileInode(f *os.File) (uint64, error) {
info, err := f.Stat()
if err != nil {
return 0, err
}
return getStatInode(info)
}
func getStatInode(info os.FileInfo) (uint64, error) {
stat, ok := info.Sys().(*syscall.Stat_t)
if !ok {
return 0, os.ErrInvalid
}
return stat.Ino, nil
}

View File

@@ -0,0 +1,110 @@
//go:build !windows
package start
import (
"context"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"databasus-agent/internal/logger"
)
func Test_NewLockWatcher_CapturesInode(t *testing.T) {
setupTempDir(t)
log := logger.GetLogger()
lockFile, err := AcquireLock(log)
require.NoError(t, err)
defer ReleaseLock(lockFile)
_, cancel := context.WithCancel(t.Context())
defer cancel()
watcher, err := NewLockWatcher(lockFile, cancel, log)
require.NoError(t, err)
assert.NotZero(t, watcher.originalInode)
}
func Test_LockWatcher_FileUnchanged_ContextNotCancelled(t *testing.T) {
setupTempDir(t)
log := logger.GetLogger()
lockFile, err := AcquireLock(log)
require.NoError(t, err)
defer ReleaseLock(lockFile)
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
watcher, err := NewLockWatcher(lockFile, cancel, log)
require.NoError(t, err)
watcher.check()
watcher.check()
watcher.check()
select {
case <-ctx.Done():
t.Fatal("context should not be cancelled when lock file is unchanged")
default:
}
}
func Test_LockWatcher_FileDeleted_CancelsContext(t *testing.T) {
setupTempDir(t)
log := logger.GetLogger()
lockFile, err := AcquireLock(log)
require.NoError(t, err)
defer ReleaseLock(lockFile)
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
watcher, err := NewLockWatcher(lockFile, cancel, log)
require.NoError(t, err)
err = os.Remove(lockFileName)
require.NoError(t, err)
watcher.check()
select {
case <-ctx.Done():
default:
t.Fatal("context should be cancelled when lock file is deleted")
}
}
func Test_LockWatcher_FileReplacedWithDifferentInode_CancelsContext(t *testing.T) {
setupTempDir(t)
log := logger.GetLogger()
lockFile, err := AcquireLock(log)
require.NoError(t, err)
defer ReleaseLock(lockFile)
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
watcher, err := NewLockWatcher(lockFile, cancel, log)
require.NoError(t, err)
err = os.Remove(lockFileName)
require.NoError(t, err)
err = os.WriteFile(lockFileName, []byte("99999\n"), 0o644)
require.NoError(t, err)
watcher.check()
select {
case <-ctx.Done():
default:
t.Fatal("context should be cancelled when lock file inode changes")
}
}

View File

@@ -0,0 +1,17 @@
//go:build windows
package start
import (
"context"
"log/slog"
"os"
)
type LockWatcher struct{}
func NewLockWatcher(_ *os.File, _ context.CancelFunc, _ *slog.Logger) (*LockWatcher, error) {
return &LockWatcher{}, nil
}
func (w *LockWatcher) Run(_ context.Context) {}

View File

@@ -0,0 +1,18 @@
//go:build windows
package start
import (
"log/slog"
"os"
)
func AcquireLock(log *slog.Logger) (*os.File, error) {
log.Warn("Process locking is not supported on Windows, skipping")
return nil, nil
}
func ReleaseLock(f *os.File) {
if f != nil {
_ = f.Close()
}
}

View File

@@ -1,21 +1,101 @@
package start
import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"os/exec"
"os/signal"
"path/filepath"
"runtime"
"strconv"
"strings"
"syscall"
"time"
"github.com/jackc/pgx/v5"
"databasus-agent/internal/config"
"databasus-agent/internal/features/api"
full_backup "databasus-agent/internal/features/full_backup"
"databasus-agent/internal/features/upgrade"
"databasus-agent/internal/features/wal"
)
func Run(cfg *config.Config, log *slog.Logger) error {
const (
pgBasebackupVerifyTimeout = 10 * time.Second
dbVerifyTimeout = 10 * time.Second
minPgMajorVersion = 15
)
func Start(cfg *config.Config, agentVersion string, isDev bool, log *slog.Logger) error {
if err := validateConfig(cfg); err != nil {
return err
}
log.Info("start: stub — not yet implemented",
"dbId", cfg.DbID,
"hasToken", cfg.Token != "",
)
if err := verifyPgBasebackup(cfg, log); err != nil {
return err
}
if err := verifyDatabase(cfg, log); err != nil {
return err
}
if runtime.GOOS == "windows" {
return RunDaemon(cfg, agentVersion, isDev, log)
}
pid, err := spawnDaemon(log)
if err != nil {
return err
}
fmt.Printf("Agent started in background (PID %d)\n", pid)
return nil
}
func RunDaemon(cfg *config.Config, agentVersion string, isDev bool, log *slog.Logger) error {
lockFile, err := AcquireLock(log)
if err != nil {
return err
}
defer ReleaseLock(lockFile)
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer cancel()
watcher, err := NewLockWatcher(lockFile, cancel, log)
if err != nil {
return fmt.Errorf("failed to initialize lock watcher: %w", err)
}
go watcher.Run(ctx)
apiClient := api.NewClient(cfg.DatabasusHost, cfg.Token, log)
var backgroundUpgrader *upgrade.BackgroundUpgrader
if agentVersion != "dev" && runtime.GOOS != "windows" {
backgroundUpgrader = upgrade.NewBackgroundUpgrader(apiClient, agentVersion, isDev, cancel, log)
go backgroundUpgrader.Run(ctx)
}
fullBackuper := full_backup.NewFullBackuper(cfg, apiClient, log)
go fullBackuper.Run(ctx)
streamer := wal.NewStreamer(cfg, apiClient, log)
streamer.Run(ctx)
if backgroundUpgrader != nil {
backgroundUpgrader.WaitForCompletion(30 * time.Second)
if backgroundUpgrader.IsUpgraded() {
return upgrade.ErrUpgradeRestart
}
}
log.Info("Agent stopped")
return nil
}
@@ -33,5 +113,213 @@ func validateConfig(cfg *config.Config) error {
return errors.New("argument token is required")
}
if cfg.PgHost == "" {
return errors.New("argument pg-host is required")
}
if cfg.PgPort <= 0 {
return errors.New("argument pg-port must be a positive number")
}
if cfg.PgUser == "" {
return errors.New("argument pg-user is required")
}
if cfg.PgType != "host" && cfg.PgType != "docker" {
return fmt.Errorf("argument pg-type must be 'host' or 'docker', got '%s'", cfg.PgType)
}
if cfg.PgWalDir == "" {
return errors.New("argument pg-wal-dir is required")
}
if cfg.PgType == "docker" && cfg.PgDockerContainerName == "" {
return errors.New("argument pg-docker-container-name is required when pg-type is 'docker'")
}
return nil
}
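
Read together, the checks above spell out the minimum configuration the agent needs before it will start. A hedged illustration of a config that would pass validateConfig for a PostgreSQL running in Docker — the helper and every value are examples, not defaults taken from these commits:

// Illustrative only; assumes it sits next to validateConfig in the same package.
func exampleDockerConfig() *config.Config {
	return &config.Config{
		DatabasusHost:         "https://databasus.example.com",
		DbID:                  "db-123",
		Token:                 "agent-token",
		PgHost:                "127.0.0.1",
		PgPort:                5432,
		PgUser:                "postgres",
		PgPassword:            "secret",
		PgType:                "docker", // "host" is the other accepted value
		PgDockerContainerName: "postgres",
		PgWalDir:              "/var/lib/databasus/wal",
	}
}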
func verifyPgBasebackup(cfg *config.Config, log *slog.Logger) error {
switch cfg.PgType {
case "host":
return verifyPgBasebackupHost(cfg, log)
case "docker":
return verifyPgBasebackupDocker(cfg, log)
default:
return fmt.Errorf("unexpected pg-type: %s", cfg.PgType)
}
}
func verifyPgBasebackupHost(cfg *config.Config, log *slog.Logger) error {
binary := "pg_basebackup"
if cfg.PgHostBinDir != "" {
binary = filepath.Join(cfg.PgHostBinDir, "pg_basebackup")
}
ctx, cancel := context.WithTimeout(context.Background(), pgBasebackupVerifyTimeout)
defer cancel()
output, err := exec.CommandContext(ctx, binary, "--version").CombinedOutput()
if err != nil {
if cfg.PgHostBinDir != "" {
return fmt.Errorf(
"pg_basebackup not found at '%s': %w. Verify pg-host-bin-dir is correct",
binary, err,
)
}
return fmt.Errorf(
"pg_basebackup not found in PATH: %w. Install PostgreSQL client tools or set pg-host-bin-dir",
err,
)
}
log.Info("pg_basebackup verified", "version", strings.TrimSpace(string(output)))
return nil
}
func verifyPgBasebackupDocker(cfg *config.Config, log *slog.Logger) error {
ctx, cancel := context.WithTimeout(context.Background(), pgBasebackupVerifyTimeout)
defer cancel()
output, err := exec.CommandContext(ctx,
"docker", "exec", cfg.PgDockerContainerName,
"pg_basebackup", "--version",
).CombinedOutput()
if err != nil {
return fmt.Errorf(
"pg_basebackup not available in container '%s': %w. "+
"Check that the container is running and pg_basebackup is installed inside it",
cfg.PgDockerContainerName, err,
)
}
log.Info("pg_basebackup verified (docker)",
"container", cfg.PgDockerContainerName,
"version", strings.TrimSpace(string(output)),
)
return nil
}
func verifyDatabase(cfg *config.Config, log *slog.Logger) error {
switch cfg.PgType {
case "docker":
return verifyDatabaseDocker(cfg, log)
default:
return verifyDatabaseHost(cfg, log)
}
}
func verifyDatabaseHost(cfg *config.Config, log *slog.Logger) error {
connStr := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=postgres sslmode=disable",
cfg.PgHost, cfg.PgPort, cfg.PgUser, cfg.PgPassword,
)
ctx, cancel := context.WithTimeout(context.Background(), dbVerifyTimeout)
defer cancel()
conn, err := pgx.Connect(ctx, connStr)
if err != nil {
return fmt.Errorf(
"failed to connect to PostgreSQL at %s:%d as user '%s': %w",
cfg.PgHost, cfg.PgPort, cfg.PgUser, err,
)
}
defer func() { _ = conn.Close(ctx) }()
if err := conn.Ping(ctx); err != nil {
return fmt.Errorf("PostgreSQL ping failed at %s:%d: %w",
cfg.PgHost, cfg.PgPort, err,
)
}
var versionNumStr string
if err := conn.QueryRow(ctx, "SHOW server_version_num").Scan(&versionNumStr); err != nil {
return fmt.Errorf("failed to query PostgreSQL version: %w", err)
}
majorVersion, err := parsePgVersionNum(versionNumStr)
if err != nil {
return fmt.Errorf("failed to parse PostgreSQL version '%s': %w", versionNumStr, err)
}
if majorVersion < minPgMajorVersion {
return fmt.Errorf(
"PostgreSQL %d is not supported, minimum required version is %d",
majorVersion, minPgMajorVersion,
)
}
log.Info("PostgreSQL connection verified",
"host", cfg.PgHost,
"port", cfg.PgPort,
"user", cfg.PgUser,
"version", majorVersion,
)
return nil
}
func verifyDatabaseDocker(cfg *config.Config, log *slog.Logger) error {
ctx, cancel := context.WithTimeout(context.Background(), dbVerifyTimeout)
defer cancel()
query := "SELECT current_setting('server_version_num')"
cmd := exec.CommandContext(ctx,
"docker", "exec",
"-e", "PGPASSWORD="+cfg.PgPassword,
cfg.PgDockerContainerName,
"psql", "-h", "localhost", "-p", "5432", "-U", cfg.PgUser,
"-d", "postgres", "-t", "-A", "-c", query,
)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf(
"failed to connect to PostgreSQL in container '%s' as user '%s': %w (output: %s)",
cfg.PgDockerContainerName, cfg.PgUser, err, strings.TrimSpace(string(output)),
)
}
versionNumStr := strings.TrimSpace(string(output))
majorVersion, err := parsePgVersionNum(versionNumStr)
if err != nil {
return fmt.Errorf("failed to parse PostgreSQL version '%s': %w", versionNumStr, err)
}
if majorVersion < minPgMajorVersion {
return fmt.Errorf(
"PostgreSQL %d is not supported, minimum required version is %d",
majorVersion, minPgMajorVersion,
)
}
log.Info("PostgreSQL connection verified (docker)",
"container", cfg.PgDockerContainerName,
"user", cfg.PgUser,
"version", majorVersion,
)
return nil
}
func parsePgVersionNum(versionNumStr string) (int, error) {
versionNum, err := strconv.Atoi(strings.TrimSpace(versionNumStr))
if err != nil {
return 0, fmt.Errorf("invalid version number: %w", err)
}
if versionNum <= 0 {
return 0, fmt.Errorf("invalid version number: %d", versionNum)
}
majorVersion := versionNum / 10000
return majorVersion, nil
}

View File

@@ -0,0 +1,84 @@
package start
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_ParsePgVersionNum_SupportedVersions_ReturnsMajorVersion(t *testing.T) {
tests := []struct {
name string
versionNumStr string
expectedMajor int
}{
{name: "PG 15.0", versionNumStr: "150000", expectedMajor: 15},
{name: "PG 15.4", versionNumStr: "150004", expectedMajor: 15},
{name: "PG 16.0", versionNumStr: "160000", expectedMajor: 16},
{name: "PG 16.3", versionNumStr: "160003", expectedMajor: 16},
{name: "PG 17.2", versionNumStr: "170002", expectedMajor: 17},
{name: "PG 18.0", versionNumStr: "180000", expectedMajor: 18},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
major, err := parsePgVersionNum(tt.versionNumStr)
require.NoError(t, err)
assert.Equal(t, tt.expectedMajor, major)
assert.GreaterOrEqual(t, major, minPgMajorVersion)
})
}
}
func Test_ParsePgVersionNum_UnsupportedVersions_ReturnsMajorVersionBelow15(t *testing.T) {
tests := []struct {
name string
versionNumStr string
expectedMajor int
}{
{name: "PG 12.5", versionNumStr: "120005", expectedMajor: 12},
{name: "PG 13.0", versionNumStr: "130000", expectedMajor: 13},
{name: "PG 14.12", versionNumStr: "140012", expectedMajor: 14},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
major, err := parsePgVersionNum(tt.versionNumStr)
require.NoError(t, err)
assert.Equal(t, tt.expectedMajor, major)
assert.Less(t, major, minPgMajorVersion)
})
}
}
func Test_ParsePgVersionNum_InvalidInput_ReturnsError(t *testing.T) {
tests := []struct {
name string
versionNumStr string
}{
{name: "empty string", versionNumStr: ""},
{name: "non-numeric", versionNumStr: "abc"},
{name: "negative number", versionNumStr: "-1"},
{name: "zero", versionNumStr: "0"},
{name: "float", versionNumStr: "15.4"},
{name: "whitespace only", versionNumStr: " "},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := parsePgVersionNum(tt.versionNumStr)
require.Error(t, err)
})
}
}
func Test_ParsePgVersionNum_WithWhitespace_ParsesCorrectly(t *testing.T) {
major, err := parsePgVersionNum(" 150004 ")
require.NoError(t, err)
assert.Equal(t, 15, major)
}

View File

@@ -0,0 +1,88 @@
package upgrade
import (
"context"
"log/slog"
"sync/atomic"
"time"
"databasus-agent/internal/features/api"
)
const backgroundCheckInterval = 10 * time.Second
type BackgroundUpgrader struct {
apiClient *api.Client
currentVersion string
isDev bool
cancel context.CancelFunc
isUpgraded atomic.Bool
log *slog.Logger
done chan struct{}
}
func NewBackgroundUpgrader(
apiClient *api.Client,
currentVersion string,
isDev bool,
cancel context.CancelFunc,
log *slog.Logger,
) *BackgroundUpgrader {
return &BackgroundUpgrader{
apiClient,
currentVersion,
isDev,
cancel,
atomic.Bool{},
log,
make(chan struct{}),
}
}
func (u *BackgroundUpgrader) Run(ctx context.Context) {
defer close(u.done)
ticker := time.NewTicker(backgroundCheckInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if u.checkAndUpgrade() {
return
}
}
}
}
func (u *BackgroundUpgrader) IsUpgraded() bool {
return u.isUpgraded.Load()
}
func (u *BackgroundUpgrader) WaitForCompletion(timeout time.Duration) {
select {
case <-u.done:
case <-time.After(timeout):
}
}
func (u *BackgroundUpgrader) checkAndUpgrade() bool {
isUpgraded, err := CheckAndUpdate(u.apiClient, u.currentVersion, u.isDev, u.log)
if err != nil {
u.log.Warn("Background update check failed", "error", err)
return false
}
if !isUpgraded {
return false
}
u.log.Info("Background upgrade complete, restarting...")
u.isUpgraded.Store(true)
u.cancel()
return true
}
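
Because CheckAndUpdate now only swaps the binary on disk, restarting is left to the caller: RunDaemon returns upgrade.ErrUpgradeRestart once the background upgrader reports success. A sketch of how the entrypoint could react, assuming a Unix re-exec in the style of the code this refactor removed (the exact call site is not part of the diff):

// Hypothetical restart handling in the agent entrypoint (Unix only).
func runWithUpgradeRestart(cfg *config.Config, agentVersion string, isDev bool, log *slog.Logger) error {
	err := start.RunDaemon(cfg, agentVersion, isDev, log)
	if errors.Is(err, upgrade.ErrUpgradeRestart) {
		selfPath, execErr := os.Executable()
		if execErr != nil {
			return execErr
		}
		// The file at selfPath has already been replaced; Exec never returns on success.
		return syscall.Exec(selfPath, os.Args, os.Environ())
	}
	return err
}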

View File

@@ -0,0 +1,5 @@
package upgrade
import "errors"
var ErrUpgradeRestart = errors.New("agent upgraded, restart required")

View File

@@ -2,49 +2,47 @@ package upgrade
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"log/slog"
"os"
"os/exec"
"runtime"
"strings"
"syscall"
"time"
"databasus-agent/internal/features/api"
)
type Logger interface {
Info(msg string, args ...any)
Warn(msg string, args ...any)
Error(msg string, args ...any)
}
type versionResponse struct {
Version string `json:"version"`
}
func CheckAndUpdate(databasusHost, currentVersion string, isDev bool, log Logger) error {
// CheckAndUpdate checks if a new version is available and upgrades the binary on disk.
// Returns (true, nil) if the binary was upgraded, (false, nil) if already up to date,
// or (false, err) on failure. Callers are responsible for re-exec or restart signaling.
func CheckAndUpdate(apiClient *api.Client, currentVersion string, isDev bool, log *slog.Logger) (bool, error) {
if isDev {
log.Info("Skipping update check (development mode)")
return nil
return false, nil
}
serverVersion, err := fetchServerVersion(databasusHost, log)
serverVersion, err := apiClient.FetchServerVersion(context.Background())
if err != nil {
return nil
log.Warn("Could not reach server for update check", "error", err)
return false, fmt.Errorf(
"unable to check version, please verify Databasus server is available: %w",
err,
)
}
if serverVersion == currentVersion {
log.Info("Agent version is up to date", "version", currentVersion)
return nil
return false, nil
}
log.Info("Updating agent...", "current", currentVersion, "target", serverVersion)
selfPath, err := os.Executable()
if err != nil {
return fmt.Errorf("failed to determine executable path: %w", err)
return false, fmt.Errorf("failed to determine executable path: %w", err)
}
tempPath := selfPath + ".update"
@@ -53,93 +51,25 @@ func CheckAndUpdate(databasusHost, currentVersion string, isDev bool, log Logger
_ = os.Remove(tempPath)
}()
if err := downloadBinary(databasusHost, tempPath); err != nil {
return fmt.Errorf("failed to download update: %w", err)
if err := apiClient.DownloadAgentBinary(context.Background(), runtime.GOARCH, tempPath); err != nil {
return false, fmt.Errorf("failed to download update: %w", err)
}
if err := os.Chmod(tempPath, 0o755); err != nil {
return fmt.Errorf("failed to set permissions on update: %w", err)
return false, fmt.Errorf("failed to set permissions on update: %w", err)
}
if err := verifyBinary(tempPath, serverVersion); err != nil {
return fmt.Errorf("update verification failed: %w", err)
return false, fmt.Errorf("update verification failed: %w", err)
}
if err := os.Rename(tempPath, selfPath); err != nil {
return fmt.Errorf("failed to replace binary (try --skip-update if this persists): %w", err)
return false, fmt.Errorf("failed to replace binary (try --skip-update if this persists): %w", err)
}
log.Info("Update complete, re-executing...")
log.Info("Agent binary updated", "version", serverVersion)
return syscall.Exec(selfPath, os.Args, os.Environ())
}
func fetchServerVersion(host string, log Logger) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
client := &http.Client{Timeout: 10 * time.Second}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, host+"/api/v1/system/version", nil)
if err != nil {
return "", err
}
resp, err := client.Do(req)
if err != nil {
log.Warn("Could not reach server for update check, continuing", "error", err)
return "", err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
log.Warn(
"Server returned non-OK status for version check, continuing",
"status",
resp.StatusCode,
)
return "", fmt.Errorf("status %d", resp.StatusCode)
}
var ver versionResponse
if err := json.NewDecoder(resp.Body).Decode(&ver); err != nil {
log.Warn("Failed to parse server version response, continuing", "error", err)
return "", err
}
return ver.Version, nil
}
func downloadBinary(host, destPath string) error {
url := fmt.Sprintf("%s/api/v1/system/agent?arch=%s", host, runtime.GOARCH)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("server returned %d for agent download", resp.StatusCode)
}
f, err := os.Create(destPath)
if err != nil {
return err
}
defer func() { _ = f.Close() }()
_, err = io.Copy(f, resp.Body)
return err
return true, nil
}
func verifyBinary(binaryPath, expectedVersion string) error {

View File

@@ -0,0 +1,204 @@
package wal
import (
"context"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"regexp"
"slices"
"strings"
"time"
"github.com/klauspost/compress/zstd"
"databasus-agent/internal/config"
"databasus-agent/internal/features/api"
)
var uploadIdleTimeout = 5 * time.Minute
const (
pollInterval = 10 * time.Second
uploadTimeout = 5 * time.Minute
)
var segmentNameRegex = regexp.MustCompile(`^[0-9A-Fa-f]{24}$`)
type Streamer struct {
cfg *config.Config
apiClient *api.Client
log *slog.Logger
}
func NewStreamer(cfg *config.Config, apiClient *api.Client, log *slog.Logger) *Streamer {
return &Streamer{
cfg: cfg,
apiClient: apiClient,
log: log,
}
}
func (s *Streamer) Run(ctx context.Context) {
s.log.Info("WAL streamer started", "pgWalDir", s.cfg.PgWalDir)
s.processQueue(ctx)
ticker := time.NewTicker(pollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
s.log.Info("WAL streamer stopping")
return
case <-ticker.C:
s.processQueue(ctx)
}
}
}
func (s *Streamer) processQueue(ctx context.Context) {
segments, err := s.listSegments()
if err != nil {
s.log.Error("Failed to list WAL segments", "error", err)
return
}
if len(segments) == 0 {
s.log.Info("No WAL segments pending", "dir", s.cfg.PgWalDir)
return
}
s.log.Info("WAL segments pending upload", "dir", s.cfg.PgWalDir, "count", len(segments))
for _, segmentName := range segments {
if ctx.Err() != nil {
return
}
if err := s.uploadSegment(ctx, segmentName); err != nil {
s.log.Error("Failed to upload WAL segment",
"segment", segmentName,
"error", err,
)
return
}
}
}
func (s *Streamer) listSegments() ([]string, error) {
entries, err := os.ReadDir(s.cfg.PgWalDir)
if err != nil {
return nil, fmt.Errorf("read wal dir: %w", err)
}
var segments []string
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if strings.HasSuffix(name, ".tmp") {
continue
}
if !segmentNameRegex.MatchString(name) {
continue
}
segments = append(segments, name)
}
slices.Sort(segments)
return segments, nil
}
func (s *Streamer) uploadSegment(ctx context.Context, segmentName string) error {
filePath := filepath.Join(s.cfg.PgWalDir, segmentName)
pr, pw := io.Pipe()
defer func() { _ = pr.Close() }()
go s.compressAndStream(pw, filePath)
uploadCtx, timeoutCancel := context.WithTimeout(ctx, uploadTimeout)
defer timeoutCancel()
idleCtx, idleCancel := context.WithCancelCause(uploadCtx)
defer idleCancel(nil)
idleReader := api.NewIdleTimeoutReader(pr, uploadIdleTimeout, idleCancel)
defer idleReader.Stop()
s.log.Info("Uploading WAL segment", "segment", segmentName)
result, err := s.apiClient.UploadWalSegment(idleCtx, segmentName, idleReader)
if err != nil {
if cause := context.Cause(idleCtx); cause != nil {
return fmt.Errorf("upload WAL segment: %w", cause)
}
return err
}
if result.IsGapDetected {
s.log.Warn("WAL chain gap detected",
"segment", segmentName,
"expected", result.ExpectedSegmentName,
"received", result.ReceivedSegmentName,
)
return fmt.Errorf("gap detected for segment %s", segmentName)
}
s.log.Info("WAL segment uploaded", "segment", segmentName)
if *s.cfg.IsDeleteWalAfterUpload {
if err := os.Remove(filePath); err != nil {
s.log.Warn("Failed to delete uploaded WAL segment",
"segment", segmentName,
"error", err,
)
}
}
return nil
}
func (s *Streamer) compressAndStream(pw *io.PipeWriter, filePath string) {
f, err := os.Open(filePath)
if err != nil {
_ = pw.CloseWithError(fmt.Errorf("open file: %w", err))
return
}
defer func() { _ = f.Close() }()
encoder, err := zstd.NewWriter(pw,
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(5)),
zstd.WithEncoderCRC(true),
)
if err != nil {
_ = pw.CloseWithError(fmt.Errorf("create zstd encoder: %w", err))
return
}
if _, err := io.Copy(encoder, f); err != nil {
_ = encoder.Close()
_ = pw.CloseWithError(fmt.Errorf("compress: %w", err))
return
}
if err := encoder.Close(); err != nil {
_ = pw.CloseWithError(fmt.Errorf("close encoder: %w", err))
return
}
_ = pw.Close()
}
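
The streamer only drains pg-wal-dir; nothing in these commits shows how segments get there. One plausible setup — an assumption, not taken from this diff — is PostgreSQL's archive_command with a copy-then-rename, so half-written files carry the .tmp suffix the streamer skips and only complete 24-hex-character segment names are picked up and uploaded in ascending order:

archive_mode = on
archive_command = 'cp %p /var/lib/databasus/wal/%f.tmp && mv /var/lib/databasus/wal/%f.tmp /var/lib/databasus/wal/%f'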

View File

@@ -0,0 +1,393 @@
package wal
import (
"context"
"crypto/rand"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/klauspost/compress/zstd"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"databasus-agent/internal/config"
"databasus-agent/internal/features/api"
"databasus-agent/internal/logger"
)
func Test_UploadSegment_SingleSegment_ServerReceivesCorrectHeadersAndBody(t *testing.T) {
walDir := createTestWalDir(t)
segmentContent := []byte("test-wal-segment-data-for-upload")
writeTestSegment(t, walDir, "000000010000000100000001", segmentContent)
var receivedHeaders http.Header
var receivedBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHeaders = r.Header.Clone()
body, err := io.ReadAll(r.Body)
require.NoError(t, err)
receivedBody = body
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
require.NotNil(t, receivedHeaders)
assert.Equal(t, "test-token", receivedHeaders.Get("Authorization"))
assert.Equal(t, "application/octet-stream", receivedHeaders.Get("Content-Type"))
assert.Equal(t, "000000010000000100000001", receivedHeaders.Get("X-Wal-Segment-Name"))
decompressed := decompressZstd(t, receivedBody)
assert.Equal(t, segmentContent, decompressed)
}
func Test_UploadSegments_MultipleSegmentsOutOfOrder_UploadedInAscendingOrder(t *testing.T) {
walDir := createTestWalDir(t)
writeTestSegment(t, walDir, "000000010000000100000003", []byte("third"))
writeTestSegment(t, walDir, "000000010000000100000001", []byte("first"))
writeTestSegment(t, walDir, "000000010000000100000002", []byte("second"))
var mu sync.Mutex
var uploadOrder []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
uploadOrder = append(uploadOrder, r.Header.Get("X-Wal-Segment-Name"))
mu.Unlock()
_, _ = io.ReadAll(r.Body)
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
mu.Lock()
defer mu.Unlock()
require.Len(t, uploadOrder, 3)
assert.Equal(t, "000000010000000100000001", uploadOrder[0])
assert.Equal(t, "000000010000000100000002", uploadOrder[1])
assert.Equal(t, "000000010000000100000003", uploadOrder[2])
}
func Test_UploadSegments_DirectoryHasTmpFiles_TmpFilesIgnored(t *testing.T) {
walDir := createTestWalDir(t)
writeTestSegment(t, walDir, "000000010000000100000001", []byte("real segment"))
writeTestSegment(t, walDir, "000000010000000100000002.tmp", []byte("partial copy"))
var mu sync.Mutex
var uploadedSegments []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
uploadedSegments = append(uploadedSegments, r.Header.Get("X-Wal-Segment-Name"))
mu.Unlock()
_, _ = io.ReadAll(r.Body)
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
mu.Lock()
defer mu.Unlock()
require.Len(t, uploadedSegments, 1)
assert.Equal(t, "000000010000000100000001", uploadedSegments[0])
}
func Test_UploadSegment_DeleteEnabled_FileRemovedAfterUpload(t *testing.T) {
walDir := createTestWalDir(t)
segmentName := "000000010000000100000001"
writeTestSegment(t, walDir, segmentName, []byte("segment data"))
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
isDeleteEnabled := true
cfg := createTestConfig(walDir, server.URL)
cfg.IsDeleteWalAfterUpload = &isDeleteEnabled
apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger())
streamer := NewStreamer(cfg, apiClient, logger.GetLogger())
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
_, err := os.Stat(filepath.Join(walDir, segmentName))
assert.True(t, os.IsNotExist(err), "segment file should be deleted after successful upload")
}
func Test_UploadSegment_DeleteDisabled_FileKeptAfterUpload(t *testing.T) {
walDir := createTestWalDir(t)
segmentName := "000000010000000100000001"
writeTestSegment(t, walDir, segmentName, []byte("segment data"))
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
isDeleteDisabled := false
cfg := createTestConfig(walDir, server.URL)
cfg.IsDeleteWalAfterUpload = &isDeleteDisabled
apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger())
streamer := NewStreamer(cfg, apiClient, logger.GetLogger())
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
_, err := os.Stat(filepath.Join(walDir, segmentName))
assert.NoError(t, err, "segment file should be kept when delete is disabled")
}
func Test_UploadSegment_ServerReturns500_FileKeptInQueue(t *testing.T) {
walDir := createTestWalDir(t)
segmentName := "000000010000000100000001"
writeTestSegment(t, walDir, segmentName, []byte("segment data"))
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(`{"error":"internal server error"}`))
}))
defer server.Close()
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
_, err := os.Stat(filepath.Join(walDir, segmentName))
assert.NoError(t, err, "segment file should remain in queue after server error")
}
func Test_ProcessQueue_EmptyDirectory_NoUploads(t *testing.T) {
walDir := createTestWalDir(t)
uploadCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
uploadCount++
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
assert.Equal(t, 0, uploadCount, "no uploads should occur for empty directory")
}
func Test_Run_ContextCancelled_StopsImmediately(t *testing.T) {
walDir := createTestWalDir(t)
streamer := newTestStreamer(walDir, "http://localhost:0")
ctx, cancel := context.WithCancel(t.Context())
cancel()
done := make(chan struct{})
go func() {
streamer.Run(ctx)
close(done)
}()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("Run should have stopped immediately when context is already cancelled")
}
}
func Test_UploadSegment_ServerReturns409_FileNotDeleted(t *testing.T) {
walDir := createTestWalDir(t)
segmentName := "000000010000000100000005"
writeTestSegment(t, walDir, segmentName, []byte("gap segment"))
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusConflict)
resp := map[string]string{
"error": "gap_detected",
"expectedSegmentName": "000000010000000100000003",
"receivedSegmentName": segmentName,
}
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
_, err := os.Stat(filepath.Join(walDir, segmentName))
assert.NoError(t, err, "segment file should not be deleted on gap detection")
}
func Test_UploadSegment_WhenUploadStalls_FailsWithIdleTimeout(t *testing.T) {
walDir := createTestWalDir(t)
// Use incompressible random data to ensure TCP buffers fill up
segmentContent := make([]byte, 1024*1024)
_, err := rand.Read(segmentContent)
require.NoError(t, err)
writeTestSegment(t, walDir, "000000010000000100000001", segmentContent)
var requestReceived atomic.Bool
handlerDone := make(chan struct{})
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestReceived.Store(true)
// Read one byte then stall — simulates a network stall
buf := make([]byte, 1)
_, _ = r.Body.Read(buf)
<-handlerDone
}))
defer server.Close()
defer close(handlerDone)
origIdleTimeout := uploadIdleTimeout
uploadIdleTimeout = 200 * time.Millisecond
defer func() { uploadIdleTimeout = origIdleTimeout }()
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
defer cancel()
uploadErr := streamer.uploadSegment(ctx, "000000010000000100000001")
assert.Error(t, uploadErr, "upload should fail when stalled")
assert.True(t, requestReceived.Load(), "server should have received the request")
assert.Contains(t, uploadErr.Error(), "idle timeout", "error should mention idle timeout")
_, statErr := os.Stat(filepath.Join(walDir, "000000010000000100000001"))
assert.NoError(t, statErr, "segment file should remain in queue after idle timeout")
}
func newTestStreamer(walDir, serverURL string) *Streamer {
cfg := createTestConfig(walDir, serverURL)
apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger())
return NewStreamer(cfg, apiClient, logger.GetLogger())
}
func createTestWalDir(t *testing.T) string {
t.Helper()
baseDir := filepath.Join(".", ".test-tmp")
if err := os.MkdirAll(baseDir, 0o755); err != nil {
t.Fatalf("failed to create base test dir: %v", err)
}
dir, err := os.MkdirTemp(baseDir, t.Name()+"-*")
if err != nil {
t.Fatalf("failed to create test wal dir: %v", err)
}
t.Cleanup(func() {
_ = os.RemoveAll(dir)
})
return dir
}
func writeTestSegment(t *testing.T, dir, name string, content []byte) {
t.Helper()
if err := os.WriteFile(filepath.Join(dir, name), content, 0o644); err != nil {
t.Fatalf("failed to write test segment %s: %v", name, err)
}
}
func createTestConfig(walDir, serverURL string) *config.Config {
isDeleteEnabled := true
return &config.Config{
DatabasusHost: serverURL,
DbID: "test-db-id",
Token: "test-token",
PgWalDir: walDir,
IsDeleteWalAfterUpload: &isDeleteEnabled,
}
}
func decompressZstd(t *testing.T, data []byte) []byte {
t.Helper()
decoder, err := zstd.NewReader(nil)
require.NoError(t, err)
defer decoder.Close()
decoded, err := decoder.DecodeAll(data, nil)
require.NoError(t, err)
return decoded
}

View File

@@ -1,47 +1,115 @@
package logger
import (
"fmt"
"io"
"log/slog"
"os"
"sync"
"time"
)
var (
loggerInstance *slog.Logger
once sync.Once
const (
logFileName = "databasus.log"
oldLogFileName = "databasus.log.old"
maxLogFileSize = 5 * 1024 * 1024 // 5MB
)
func Init(isDebug bool) {
level := slog.LevelInfo
if isDebug {
level = slog.LevelDebug
}
once.Do(func() {
loggerInstance = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: level,
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
if a.Key == slog.TimeKey {
a.Value = slog.StringValue(time.Now().Format("2006/01/02 15:04:05"))
}
if a.Key == slog.LevelKey {
return slog.Attr{}
}
return a
},
}))
loggerInstance.Info("Text structured logger initialized")
})
type rotatingWriter struct {
mu sync.Mutex
file *os.File
currentSize int64
maxSize int64
logPath string
oldLogPath string
}
// GetLogger returns a singleton slog.Logger that logs to the console
func GetLogger() *slog.Logger {
if loggerInstance == nil {
Init(false)
func (w *rotatingWriter) Write(p []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
if w.currentSize+int64(len(p)) > w.maxSize {
if err := w.rotate(); err != nil {
return 0, fmt.Errorf("failed to rotate log file: %w", err)
}
}
n, err := w.file.Write(p)
w.currentSize += int64(n)
return n, err
}
func (w *rotatingWriter) rotate() error {
if err := w.file.Close(); err != nil {
return fmt.Errorf("failed to close %s: %w", w.logPath, err)
}
if err := os.Remove(w.oldLogPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove %s: %w", w.oldLogPath, err)
}
if err := os.Rename(w.logPath, w.oldLogPath); err != nil {
return fmt.Errorf("failed to rename %s to %s: %w", w.logPath, w.oldLogPath, err)
}
f, err := os.OpenFile(w.logPath, os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return fmt.Errorf("failed to create new %s: %w", w.logPath, err)
}
w.file = f
w.currentSize = 0
return nil
}
var loggerInstance *slog.Logger
var initLogger = sync.OnceFunc(initialize)
func GetLogger() *slog.Logger {
initLogger()
return loggerInstance
}
func initialize() {
writer := buildWriter()
loggerInstance = slog.New(slog.NewTextHandler(writer, &slog.HandlerOptions{
Level: slog.LevelInfo,
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
if a.Key == slog.TimeKey {
a.Value = slog.StringValue(time.Now().Format("2006/01/02 15:04:05"))
}
if a.Key == slog.LevelKey {
return slog.Attr{}
}
return a
},
}))
}
func buildWriter() io.Writer {
f, err := os.OpenFile(logFileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
fmt.Fprintf(os.Stderr, "Warning: failed to open %s for logging: %v\n", logFileName, err)
return os.Stdout
}
var currentSize int64
if info, err := f.Stat(); err == nil {
currentSize = info.Size()
}
rw := &rotatingWriter{
file: f,
currentSize: currentSize,
maxSize: maxLogFileSize,
logPath: logFileName,
oldLogPath: oldLogFileName,
}
return io.MultiWriter(os.Stdout, rw)
}
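
A brief usage note for the rewritten logger, assuming callers import it as the tests elsewhere in this diff do: every record goes both to stdout and to databasus.log, and once the file exceeds roughly 5MB it is renamed to databasus.log.old, keeping a single old generation.

log := logger.GetLogger() // import "databasus-agent/internal/logger"
log.Info("WAL segment uploaded", "segment", "000000010000000100000001")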

View File

@@ -0,0 +1,128 @@
package logger
import (
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_Write_DataWrittenToFile(t *testing.T) {
rw, logPath, _ := setupRotatingWriter(t, 1024)
data := []byte("hello world\n")
n, err := rw.Write(data)
require.NoError(t, err)
assert.Equal(t, len(data), n)
assert.Equal(t, int64(len(data)), rw.currentSize)
content, err := os.ReadFile(logPath)
require.NoError(t, err)
assert.Equal(t, string(data), string(content))
}
func Test_Write_WhenLimitExceeded_FileRotated(t *testing.T) {
rw, logPath, oldLogPath := setupRotatingWriter(t, 100)
firstData := []byte(strings.Repeat("A", 80))
_, err := rw.Write(firstData)
require.NoError(t, err)
secondData := []byte(strings.Repeat("B", 30))
_, err = rw.Write(secondData)
require.NoError(t, err)
oldContent, err := os.ReadFile(oldLogPath)
require.NoError(t, err)
assert.Equal(t, string(firstData), string(oldContent))
newContent, err := os.ReadFile(logPath)
require.NoError(t, err)
assert.Equal(t, string(secondData), string(newContent))
assert.Equal(t, int64(len(secondData)), rw.currentSize)
}
func Test_Write_WhenOldFileExists_OldFileReplaced(t *testing.T) {
rw, _, oldLogPath := setupRotatingWriter(t, 100)
require.NoError(t, os.WriteFile(oldLogPath, []byte("stale data"), 0o644))
_, err := rw.Write([]byte(strings.Repeat("A", 80)))
require.NoError(t, err)
_, err = rw.Write([]byte(strings.Repeat("B", 30)))
require.NoError(t, err)
oldContent, err := os.ReadFile(oldLogPath)
require.NoError(t, err)
assert.Equal(t, strings.Repeat("A", 80), string(oldContent))
}
func Test_Write_MultipleSmallWrites_CurrentSizeAccumulated(t *testing.T) {
rw, _, _ := setupRotatingWriter(t, 1024)
var totalWritten int64
for range 10 {
data := []byte("line\n")
n, err := rw.Write(data)
require.NoError(t, err)
totalWritten += int64(n)
}
assert.Equal(t, totalWritten, rw.currentSize)
assert.Equal(t, int64(50), rw.currentSize)
}
func Test_Write_ExactlyAtBoundary_NoRotationUntilNextByte(t *testing.T) {
rw, logPath, oldLogPath := setupRotatingWriter(t, 100)
exactData := []byte(strings.Repeat("X", 100))
_, err := rw.Write(exactData)
require.NoError(t, err)
_, err = os.Stat(oldLogPath)
assert.True(t, os.IsNotExist(err), ".old file should not exist yet")
content, err := os.ReadFile(logPath)
require.NoError(t, err)
assert.Equal(t, string(exactData), string(content))
_, err = rw.Write([]byte("Z"))
require.NoError(t, err)
_, err = os.Stat(oldLogPath)
assert.NoError(t, err, ".old file should exist after exceeding limit")
assert.Equal(t, int64(1), rw.currentSize)
}
func setupRotatingWriter(t *testing.T, maxSize int64) (*rotatingWriter, string, string) {
t.Helper()
dir := t.TempDir()
logPath := filepath.Join(dir, "test.log")
oldLogPath := filepath.Join(dir, "test.log.old")
f, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY, 0o644)
require.NoError(t, err)
rw := &rotatingWriter{
file: f,
currentSize: 0,
maxSize: maxSize,
logPath: logPath,
oldLogPath: oldLogPath,
}
t.Cleanup(func() {
rw.file.Close()
})
return rw, logPath, oldLogPath
}

View File

@@ -27,6 +27,13 @@ VALKEY_PORT=6379
VALKEY_USERNAME=
VALKEY_PASSWORD=
VALKEY_IS_SSL=false
# billing
PRICE_PER_GB_CENTS=
IS_PADDLE_SANDBOX=true
PADDLE_API_KEY=
PADDLE_WEBHOOK_SECRET=
PADDLE_PRICE_ID=
PADDLE_CLIENT_TOKEN=
# testing
# to get Google Drive env variables: add storage in UI and copy data from added storage here
TEST_GOOGLE_DRIVE_CLIENT_ID=
@@ -45,9 +52,6 @@ TEST_MINIO_PORT=9000
TEST_MINIO_CONSOLE_PORT=9001
# testing NAS
TEST_NAS_PORT=7006
# testing Telegram
TEST_TELEGRAM_BOT_TOKEN=
TEST_TELEGRAM_CHAT_ID=
# testing Azure Blob Storage
TEST_AZURITE_BLOB_PORT=10000
# supabase

View File

@@ -9,6 +9,7 @@ import (
"os/exec"
"os/signal"
"path/filepath"
"runtime/debug"
"syscall"
"time"
@@ -25,6 +26,8 @@ import (
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/billing"
billing_paddle "databasus-backend/internal/features/billing/paddle"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
"databasus-backend/internal/features/encryption/secrets"
@@ -105,7 +108,9 @@ func main() {
go generateSwaggerDocs(log)
gin.SetMode(gin.ReleaseMode)
ginApp := gin.Default()
ginApp := gin.New()
ginApp.Use(gin.Logger())
ginApp.Use(ginRecoveryWithLogger(log))
// Add GZIP compression middleware
ginApp.Use(gzip.Gzip(
@@ -188,7 +193,7 @@ func startServerWithGracefulShutdown(log *slog.Logger, app *gin.Engine) {
log.Info("Shutdown signal received")
// Gracefully shutdown VictoriaLogs writer
logger.ShutdownVictoriaLogs(5 * time.Second)
logger.ShutdownVictoriaLogs()
// The context is used to inform the server it has 10 seconds to finish
// the request it is currently handling
@@ -217,6 +222,10 @@ func setUpRoutes(r *gin.Engine) {
backups_controllers.GetPostgresWalBackupController().RegisterRoutes(v1)
databases.GetDatabaseController().RegisterPublicRoutes(v1)
if config.GetEnv().IsCloud {
billing_paddle.GetPaddleBillingController().RegisterPublicRoutes(v1)
}
// Setup auth middleware
userService := users_services.GetUserService()
authMiddleware := users_middleware.AuthMiddleware(userService)
@@ -240,6 +249,7 @@ func setUpRoutes(r *gin.Engine) {
audit_logs.GetAuditLogController().RegisterRoutes(protected)
users_controllers.GetManagementController().RegisterRoutes(protected)
users_controllers.GetSettingsController().RegisterRoutes(protected)
billing.GetBillingController().RegisterRoutes(protected)
}
func setUpDependencies() {
@@ -252,6 +262,11 @@ func setUpDependencies() {
storages.SetupDependencies()
backups_config.SetupDependencies()
task_cancellation.SetupDependencies()
billing.SetupDependencies()
if config.GetEnv().IsCloud {
billing_paddle.SetupDependencies()
}
}
func runBackgroundTasks(log *slog.Logger) {
@@ -308,6 +323,12 @@ func runBackgroundTasks(log *slog.Logger) {
go runWithPanicLogging(log, "restore nodes registry background service", func() {
restoring.GetRestoreNodesRegistry().Run(ctx)
})
if config.GetEnv().IsCloud {
go runWithPanicLogging(log, "billing background service", func() {
billing.GetBillingService().Run(ctx, *log)
})
}
} else {
log.Info("Skipping primary node tasks as not primary node")
}
@@ -330,7 +351,7 @@ func runBackgroundTasks(log *slog.Logger) {
func runWithPanicLogging(log *slog.Logger, serviceName string, fn func()) {
defer func() {
if r := recover(); r != nil {
log.Error("Panic in "+serviceName, "error", r)
log.Error("Panic in "+serviceName, "error", r, "stacktrace", string(debug.Stack()))
}
}()
fn()
@@ -410,6 +431,25 @@ func enableCors(ginApp *gin.Engine) {
}
}
func ginRecoveryWithLogger(log *slog.Logger) gin.HandlerFunc {
return func(ctx *gin.Context) {
defer func() {
if r := recover(); r != nil {
log.Error("Panic recovered in HTTP handler",
"error", r,
"stacktrace", string(debug.Stack()),
"method", ctx.Request.Method,
"path", ctx.Request.URL.Path,
)
ctx.AbortWithStatus(http.StatusInternalServerError)
}
}()
ctx.Next()
}
}
func mountFrontend(ginApp *gin.Engine) {
staticDir := "./ui/build"
ginApp.NoRoute(func(c *gin.Context) {

View File

@@ -5,6 +5,7 @@ go 1.26.1
require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3
github.com/PaddleHQ/paddle-go-sdk v1.0.0
github.com/gin-contrib/cors v1.7.5
github.com/gin-contrib/gzip v1.2.3
github.com/gin-gonic/gin v1.10.0
@@ -100,6 +101,8 @@ require (
github.com/emersion/go-message v0.18.2 // indirect
github.com/emersion/go-vcard v0.0.0-20241024213814-c9703dde27ff // indirect
github.com/flynn/noise v1.1.0 // indirect
github.com/ggicci/httpin v0.19.0 // indirect
github.com/ggicci/owl v0.8.2 // indirect
github.com/go-chi/chi/v5 v5.2.3 // indirect
github.com/go-darwin/apfs v0.0.0-20211011131704-f84b94dbf348 // indirect
github.com/go-git/go-billy/v5 v5.6.2 // indirect

View File

@@ -77,6 +77,8 @@ github.com/Max-Sum/base32768 v0.0.0-20230304063302-18e6ce5945fd/go.mod h1:C8yoIf
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/PaddleHQ/paddle-go-sdk v1.0.0 h1:+EXitsPFbRcc0CpQE/MIeudxiVOR8pFe/aOWTEUHDKU=
github.com/PaddleHQ/paddle-go-sdk v1.0.0/go.mod h1:kbBBzf0BHEj38QvhtoELqlGip3alKgA/I+vl7RQzB58=
github.com/ProtonMail/bcrypt v0.0.0-20210511135022-227b4adcab57/go.mod h1:HecWFHognK8GfRDGnFQbW/LiV7A3MX3gZVs45vk5h8I=
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs69zUkSzubzjBbL+cmOXgnmt9Fyd9ug=
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo=
@@ -248,6 +250,10 @@ github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9t
github.com/geoffgarside/ber v1.1.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc=
github.com/geoffgarside/ber v1.2.0 h1:/loowoRcs/MWLYmGX9QtIAbA+V/FrnVLsMMPhwiRm64=
github.com/geoffgarside/ber v1.2.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc=
github.com/ggicci/httpin v0.19.0 h1:p0B3SWLVgg770VirYiHB14M5wdRx3zR8mCTzM/TkTQ8=
github.com/ggicci/httpin v0.19.0/go.mod h1:hzsQHcbqLabmGOycf7WNw6AAzcVbsMeoOp46bWAbIWc=
github.com/ggicci/owl v0.8.2 h1:og+lhqpzSMPDdEB+NJfzoAJARP7qCG3f8uUC3xvGukA=
github.com/ggicci/owl v0.8.2/go.mod h1:PHRD57u41vFN5UtFz2SF79yTVoM3HlWpjMiE+ZU2dj4=
github.com/gin-contrib/cors v1.7.5 h1:cXC9SmofOrRg0w9PigwGlHG3ztswH6bqq4vJVXnvYMk=
github.com/gin-contrib/cors v1.7.5/go.mod h1:4q3yi7xBEDDWKapjT2o1V7mScKDDr8k+jZ0fSquGoy0=
github.com/gin-contrib/gzip v1.2.3 h1:dAhT722RuEG330ce2agAs75z7yB+NKvX/ZM1r8w0u2U=
@@ -454,6 +460,8 @@ github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
github.com/jtolio/noiseconn v0.0.0-20231127013910-f6d9ecbf1de7 h1:JcltaO1HXM5S2KYOYcKgAV7slU0xPy1OcvrVgn98sRQ=
github.com/jtolio/noiseconn v0.0.0-20231127013910-f6d9ecbf1de7/go.mod h1:MEkhEPFwP3yudWO0lj6vfYpLIB+3eIcuIW+e0AZzUQk=
github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo=
github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA=
github.com/jzelinskie/whirlpool v0.0.0-20201016144138-0675e54bb004 h1:G+9t9cEtnC9jFiTxyptEKuNIAbiN5ZCQzX2a74lj3xg=
github.com/jzelinskie/whirlpool v0.0.0-20201016144138-0675e54bb004/go.mod h1:KmHnJWQrgEvbuy0vcvj00gtMqbvNn1L+3YUZLK/B92c=
github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU=

View File

@@ -5,6 +5,7 @@ import (
"path/filepath"
"strings"
"sync"
"time"
"github.com/ilyakaznacheev/cleanenv"
"github.com/joho/godotenv"
@@ -53,6 +54,20 @@ type EnvVariables struct {
TempFolder string
SecretKeyPath string
// Billing (always tax-exclusive)
PricePerGBCents int64 `env:"PRICE_PER_GB_CENTS"`
MinStorageGB int
MaxStorageGB int
TrialDuration time.Duration
TrialStorageGB int
GracePeriod time.Duration
// Paddle billing
IsPaddleSandbox bool `env:"IS_PADDLE_SANDBOX"`
PaddleApiKey string `env:"PADDLE_API_KEY"`
PaddleWebhookSecret string `env:"PADDLE_WEBHOOK_SECRET"`
PaddlePriceID string `env:"PADDLE_PRICE_ID"`
PaddleClientToken string `env:"PADDLE_CLIENT_TOKEN"`
TestGoogleDriveClientID string `env:"TEST_GOOGLE_DRIVE_CLIENT_ID"`
TestGoogleDriveClientSecret string `env:"TEST_GOOGLE_DRIVE_CLIENT_SECRET"`
TestGoogleDriveTokenJSON string `env:"TEST_GOOGLE_DRIVE_TOKEN_JSON"`
@@ -109,10 +124,6 @@ type EnvVariables struct {
CloudflareTurnstileSecretKey string `env:"CLOUDFLARE_TURNSTILE_SECRET_KEY"`
CloudflareTurnstileSiteKey string `env:"CLOUDFLARE_TURNSTILE_SITE_KEY"`
// testing Telegram
TestTelegramBotToken string `env:"TEST_TELEGRAM_BOT_TOKEN"`
TestTelegramChatID string `env:"TEST_TELEGRAM_CHAT_ID"`
// testing Supabase
TestSupabaseHost string `env:"TEST_SUPABASE_HOST"`
TestSupabasePort string `env:"TEST_SUPABASE_PORT"`
@@ -131,14 +142,13 @@ type EnvVariables struct {
DatabasusURL string `env:"DATABASUS_URL"`
}
var (
env EnvVariables
once sync.Once
)
var env EnvVariables
func GetEnv() EnvVariables {
once.Do(loadEnvVariables)
return env
var initEnv = sync.OnceFunc(loadEnvVariables)
func GetEnv() *EnvVariables {
initEnv()
return &env
}
func loadEnvVariables() {
@@ -365,16 +375,41 @@ func loadEnvVariables() {
os.Exit(1)
}
if env.TestTelegramBotToken == "" {
log.Error("TEST_TELEGRAM_BOT_TOKEN is empty")
}
// Billing
if env.IsCloud {
if env.PricePerGBCents == 0 {
log.Error("PRICE_PER_GB_CENTS is empty or zero")
os.Exit(1)
}
if env.TestTelegramChatID == "" {
log.Error("TEST_TELEGRAM_CHAT_ID is empty")
if env.PaddleApiKey == "" {
log.Error("PADDLE_API_KEY is empty")
os.Exit(1)
}
if env.PaddleWebhookSecret == "" {
log.Error("PADDLE_WEBHOOK_SECRET is empty")
os.Exit(1)
}
if env.PaddlePriceID == "" {
log.Error("PADDLE_PRICE_ID is empty")
os.Exit(1)
}
if env.PaddleClientToken == "" {
log.Error("PADDLE_CLIENT_TOKEN is empty")
os.Exit(1)
}
}
env.MinStorageGB = 20
env.MaxStorageGB = 10_000
env.TrialDuration = 24 * time.Hour
env.TrialStorageGB = 20
env.GracePeriod = 30 * 24 * time.Hour
log.Info("Environment variables loaded successfully!")
}
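
Illustrative note (not part of the diff): the GetEnv change above replaces an explicit sync.Once field with sync.OnceFunc (Go 1.21+), which wraps a function so repeated calls execute it only once, and returns a pointer so all callers share one loaded instance. A minimal hedged sketch of the same pattern, using hypothetical Config/loadConfig names rather than the real EnvVariables/loadEnvVariables:

// Sketch only: lazy one-time loading with sync.OnceFunc.
package config

import "sync"

type Config struct { // hypothetical stand-in for the env struct
	PricePerGBCents int64
}

var cfg Config

// loadConfig is a hypothetical loader; the real code reads env variables.
func loadConfig() { cfg.PricePerGBCents = 50 }

var initConfig = sync.OnceFunc(loadConfig)

// GetConfig runs loadConfig exactly once, regardless of how many goroutines
// call it, then returns a pointer to the shared instance.
func GetConfig() *Config {
	initConfig()
	return &cfg
}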

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
)
@@ -13,39 +12,32 @@ type AuditLogBackgroundService struct {
auditLogService *AuditLogService
logger *slog.Logger
runOnce sync.Once
hasRun atomic.Bool
hasRun atomic.Bool
}
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
wasAlreadyRun := s.hasRun.Load()
if s.hasRun.Swap(true) {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
s.runOnce.Do(func() {
s.hasRun.Store(true)
s.logger.Info("Starting audit log cleanup background service")
s.logger.Info("Starting audit log cleanup background service")
if ctx.Err() != nil {
return
}
if ctx.Err() != nil {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
}
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
case <-ticker.C:
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}
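
Illustrative note (not part of the diff): the new guard relies on atomic.Bool.Swap, which sets the flag and returns its previous value in one atomic step, so a repeated Run() call can be detected without a separate sync.Once. A small self-contained sketch of that guard:

// Sketch only: rejecting a second Run() call with atomic.Bool.Swap.
package main

import (
	"fmt"
	"sync/atomic"
)

type service struct {
	hasRun atomic.Bool
}

// Run panics on any call after the first: Swap(true) returns the old value,
// which is true once the service has already started.
func (s *service) Run() {
	if s.hasRun.Swap(true) {
		panic("Run() called multiple times")
	}
	fmt.Println("running")
}

func main() {
	s := &service{}
	s.Run()  // prints "running"
	// s.Run() // would panic: Run() called multiple times
}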

View File

@@ -102,7 +102,7 @@ func Test_CleanOldAuditLogs_DeletesMultipleOldLogs(t *testing.T) {
// Create many old logs with specific UUIDs to track them
testLogIDs := make([]uuid.UUID, 5)
for i := 0; i < 5; i++ {
for i := range 5 {
testLogIDs[i] = uuid.New()
daysAgo := 400 + (i * 10)
log := &AuditLog{

View File

@@ -2,7 +2,6 @@ package audit_logs
import (
"sync"
"sync/atomic"
users_services "databasus-backend/internal/features/users/services"
"databasus-backend/internal/util/logger"
@@ -23,8 +22,6 @@ var auditLogController = &AuditLogController{
var auditLogBackgroundService = &AuditLogBackgroundService{
auditLogService: auditLogService,
logger: logger.GetLogger(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
func GetAuditLogService() *AuditLogService {
@@ -39,23 +36,8 @@ func GetAuditLogBackgroundService() *AuditLogBackgroundService {
return auditLogBackgroundService
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
users_services.GetUserService().SetAuditLogWriter(auditLogService)
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}
var SetupDependencies = sync.OnceFunc(func() {
users_services.GetUserService().SetAuditLogWriter(auditLogService)
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
})

View File

@@ -9,7 +9,6 @@ import (
"log/slog"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
@@ -46,80 +45,73 @@ type BackuperNode struct {
lastHeartbeat time.Time
runOnce sync.Once
hasRun atomic.Bool
hasRun atomic.Bool
}
func (n *BackuperNode) Run(ctx context.Context) {
wasAlreadyRun := n.hasRun.Load()
if n.hasRun.Swap(true) {
panic(fmt.Sprintf("%T.Run() called multiple times", n))
}
n.runOnce.Do(func() {
n.hasRun.Store(true)
n.lastHeartbeat = time.Now().UTC()
n.lastHeartbeat = time.Now().UTC()
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
backupNode := BackupNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: time.Now().UTC(),
}
backupNode := BackupNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: time.Now().UTC(),
}
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
go func() {
n.MakeBackup(backupID, isCallNotifier)
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
err,
"backupID",
backupID,
)
}
}()
}
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
if err != nil {
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
panic(err)
}
defer func() {
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
go func() {
n.MakeBackup(backupID, isCallNotifier)
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
err,
"backupID",
backupID,
)
}
}()
}
ticker := time.NewTicker(heartbeatTickerInterval)
defer ticker.Stop()
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
}
return
case <-ticker.C:
n.sendHeartbeat(&backupNode)
}
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
if err != nil {
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
panic(err)
}
defer func() {
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
}
})
}()
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", n))
ticker := time.NewTicker(heartbeatTickerInterval)
defer ticker.Stop()
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
}
return
case <-ticker.C:
n.sendHeartbeat(&backupNode)
}
}
}
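
Illustrative note (not part of the diff): the flattened Run above follows the common Go shape for a periodic task that stops on shutdown, where a time.Ticker drives the work and ctx.Done() ends the loop. A stripped-down, hedged sketch of that loop shape (the heartbeat names here are illustrative, not the project's API):

// Sketch only: ticker-driven loop that exits on context cancellation.
package main

import (
	"context"
	"fmt"
	"time"
)

func runHeartbeats(ctx context.Context, interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			fmt.Println("shutdown: unregistering node")
			return
		case <-ticker.C:
			fmt.Println("heartbeat sent")
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 350*time.Millisecond)
	defer cancel()
	runHeartbeats(ctx, 100*time.Millisecond) // a few heartbeats, then clean shutdown
}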
@@ -171,26 +163,6 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
backup.BackupSizeMb = completedMBs
backup.BackupDurationMs = time.Since(start).Milliseconds()
// Check size limit (0 = unlimited)
if backupConfig.MaxBackupSizeMB > 0 &&
completedMBs > float64(backupConfig.MaxBackupSizeMB) {
errMsg := fmt.Sprintf(
"backup size (%.2f MB) exceeded maximum allowed size (%d MB)",
completedMBs,
backupConfig.MaxBackupSizeMB,
)
backup.Status = backups_core.BackupStatusFailed
backup.IsSkipRetry = true
backup.FailMessage = &errMsg
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to save backup with size exceeded error", "error", err)
}
cancel() // Cancel the backup context
return
}
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to update backup progress", "error", err)
}

View File

@@ -153,121 +153,3 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
assert.Equal(t, notifier.ID, capturedNotifier.ID)
})
}
func Test_BackupSizeLimits(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
t.Run("UnlimitedSize_MaxBackupSizeMBIsZero_BackupCompletes", func(t *testing.T) {
// Enable backups with unlimited size (0)
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 0 // unlimited
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateLargeBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup completed successfully even with large size
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
assert.Equal(t, float64(10000), updatedBackup.BackupSizeMb)
assert.Nil(t, updatedBackup.FailMessage)
})
t.Run("SizeExceeded_BackupFailedWithIsSkipRetry", func(t *testing.T) {
// Enable backups with 5 MB limit
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 5
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateProgressiveBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup was marked as failed with IsSkipRetry=true
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusFailed, updatedBackup.Status)
assert.True(t, updatedBackup.IsSkipRetry)
assert.NotNil(t, updatedBackup.FailMessage)
assert.Contains(t, *updatedBackup.FailMessage, "exceeded maximum allowed size")
assert.Contains(t, *updatedBackup.FailMessage, "10.00 MB")
assert.Contains(t, *updatedBackup.FailMessage, "5 MB")
assert.Greater(t, updatedBackup.BackupSizeMb, float64(5))
})
t.Run("SizeWithinLimit_BackupCompletes", func(t *testing.T) {
// Enable backups with 100 MB limit
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 100
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateMediumBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup completed successfully
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
assert.Equal(t, float64(50), updatedBackup.BackupSizeMb)
assert.Nil(t, updatedBackup.FailMessage)
})
}

View File

@@ -4,12 +4,12 @@ import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
@@ -26,45 +26,47 @@ type BackupCleaner struct {
backupRepository *backups_core.BackupRepository
storageService *storages.StorageService
backupConfigService *backups_config.BackupConfigService
billingService BillingService
fieldEncryptor util_encryption.FieldEncryptor
logger *slog.Logger
backupRemoveListeners []backups_core.BackupRemoveListener
runOnce sync.Once
hasRun atomic.Bool
hasRun atomic.Bool
}
func (c *BackupCleaner) Run(ctx context.Context) {
wasAlreadyRun := c.hasRun.Load()
if c.hasRun.Swap(true) {
panic(fmt.Sprintf("%T.Run() called multiple times", c))
}
c.runOnce.Do(func() {
c.hasRun.Store(true)
if ctx.Err() != nil {
return
}
if ctx.Err() != nil {
retentionLog := c.logger.With("task_name", "clean_by_retention_policy")
exceededLog := c.logger.With("task_name", "clean_exceeded_storage_backups")
staleLog := c.logger.With("task_name", "clean_stale_basebackups")
ticker := time.NewTicker(cleanerTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
}
case <-ticker.C:
if err := c.cleanByRetentionPolicy(retentionLog); err != nil {
retentionLog.Error("failed to clean backups by retention policy", "error", err)
}
ticker := time.NewTicker(cleanerTickerInterval)
defer ticker.Stop()
if err := c.cleanExceededStorageBackups(exceededLog); err != nil {
exceededLog.Error("failed to clean exceeded backups", "error", err)
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := c.cleanByRetentionPolicy(); err != nil {
c.logger.Error("Failed to clean backups by retention policy", "error", err)
}
if err := c.cleanExceededBackups(); err != nil {
c.logger.Error("Failed to clean exceeded backups", "error", err)
}
if err := c.cleanStaleUploadedBasebackups(staleLog); err != nil {
staleLog.Error("failed to clean stale uploaded basebackups", "error", err)
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", c))
}
}
@@ -100,59 +102,109 @@ func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemo
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
}
func (c *BackupCleaner) cleanByRetentionPolicy() error {
func (c *BackupCleaner) cleanStaleUploadedBasebackups(logger *slog.Logger) error {
staleBackups, err := c.backupRepository.FindStaleUploadedBasebackups(
time.Now().UTC().Add(-10 * time.Minute),
)
if err != nil {
return fmt.Errorf("failed to find stale uploaded basebackups: %w", err)
}
for _, backup := range staleBackups {
backupLog := logger.With("database_id", backup.DatabaseID, "backup_id", backup.ID)
staleStorage, storageErr := c.storageService.GetStorageByID(backup.StorageID)
if storageErr != nil {
backupLog.Error(
"failed to get storage for stale basebackup cleanup",
"storage_id", backup.StorageID,
"error", storageErr,
)
} else {
if err := staleStorage.DeleteFile(c.fieldEncryptor, backup.FileName); err != nil {
backupLog.Error(
fmt.Sprintf("failed to delete stale basebackup file: %s", backup.FileName),
"error",
err,
)
}
metadataFileName := backup.FileName + ".metadata"
if err := staleStorage.DeleteFile(c.fieldEncryptor, metadataFileName); err != nil {
backupLog.Error(
fmt.Sprintf("failed to delete stale basebackup metadata file: %s", metadataFileName),
"error",
err,
)
}
}
failMsg := "basebackup finalization timed out after 10 minutes"
backup.Status = backups_core.BackupStatusFailed
backup.FailMessage = &failMsg
if err := c.backupRepository.Save(backup); err != nil {
backupLog.Error("failed to mark stale uploaded basebackup as failed", "error", err)
continue
}
backupLog.Info("marked stale uploaded basebackup as failed and cleaned storage")
}
return nil
}
func (c *BackupCleaner) cleanByRetentionPolicy(logger *slog.Logger) error {
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
dbLog := logger.With("database_id", backupConfig.DatabaseID, "policy", backupConfig.RetentionPolicyType)
var cleanErr error
switch backupConfig.RetentionPolicyType {
case backups_config.RetentionPolicyTypeCount:
cleanErr = c.cleanByCount(backupConfig)
cleanErr = c.cleanByCount(dbLog, backupConfig)
case backups_config.RetentionPolicyTypeGFS:
cleanErr = c.cleanByGFS(backupConfig)
cleanErr = c.cleanByGFS(dbLog, backupConfig)
default:
cleanErr = c.cleanByTimePeriod(backupConfig)
cleanErr = c.cleanByTimePeriod(dbLog, backupConfig)
}
if cleanErr != nil {
c.logger.Error(
"Failed to clean backups by retention policy",
"databaseId", backupConfig.DatabaseID,
"policy", backupConfig.RetentionPolicyType,
"error", cleanErr,
)
dbLog.Error("failed to clean backups by retention policy", "error", cleanErr)
}
}
return nil
}
func (c *BackupCleaner) cleanExceededBackups() error {
func (c *BackupCleaner) cleanExceededStorageBackups(logger *slog.Logger) error {
if !config.GetEnv().IsCloud {
return nil
}
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
if backupConfig.MaxBackupsTotalSizeMB <= 0 {
dbLog := logger.With("database_id", backupConfig.DatabaseID)
subscription, subErr := c.billingService.GetSubscription(dbLog, backupConfig.DatabaseID)
if subErr != nil {
dbLog.Error("failed to get subscription for exceeded backups check", "error", subErr)
continue
}
if err := c.cleanExceededBackupsForDatabase(
backupConfig.DatabaseID,
backupConfig.MaxBackupsTotalSizeMB,
); err != nil {
c.logger.Error(
"Failed to clean exceeded backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
storageLimitMB := int64(subscription.GetBackupsStorageGB()) * 1024
if err := c.cleanExceededBackupsForDatabase(dbLog, backupConfig.DatabaseID, storageLimitMB); err != nil {
dbLog.Error("failed to clean exceeded backups for database", "error", err)
continue
}
}
@@ -160,7 +212,7 @@ func (c *BackupCleaner) cleanExceededBackups() error {
return nil
}
func (c *BackupCleaner) cleanByTimePeriod(backupConfig *backups_config.BackupConfig) error {
func (c *BackupCleaner) cleanByTimePeriod(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
if backupConfig.RetentionTimePeriod == "" {
return nil
}
@@ -190,21 +242,17 @@ func (c *BackupCleaner) cleanByTimePeriod(backupConfig *backups_config.BackupCon
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
logger.Error("failed to delete old backup", "backup_id", backup.ID, "error", err)
continue
}
c.logger.Info(
"Deleted old backup",
"backupId", backup.ID,
"databaseId", backupConfig.DatabaseID,
)
logger.Info("deleted old backup", "backup_id", backup.ID)
}
return nil
}
func (c *BackupCleaner) cleanByCount(backupConfig *backups_config.BackupConfig) error {
func (c *BackupCleaner) cleanByCount(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
if backupConfig.RetentionCount <= 0 {
return nil
}
@@ -233,28 +281,20 @@ func (c *BackupCleaner) cleanByCount(backupConfig *backups_config.BackupConfig)
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete backup by count policy",
"backupId",
backup.ID,
"error",
err,
)
logger.Error("failed to delete backup by count policy", "backup_id", backup.ID, "error", err)
continue
}
c.logger.Info(
"Deleted backup by count policy",
"backupId", backup.ID,
"databaseId", backupConfig.DatabaseID,
"retentionCount", backupConfig.RetentionCount,
logger.Info(
fmt.Sprintf("deleted backup by count policy: retention count is %d", backupConfig.RetentionCount),
"backup_id", backup.ID,
)
}
return nil
}
func (c *BackupCleaner) cleanByGFS(backupConfig *backups_config.BackupConfig) error {
func (c *BackupCleaner) cleanByGFS(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
if backupConfig.RetentionGfsHours <= 0 && backupConfig.RetentionGfsDays <= 0 &&
backupConfig.RetentionGfsWeeks <= 0 && backupConfig.RetentionGfsMonths <= 0 &&
backupConfig.RetentionGfsYears <= 0 {
@@ -292,29 +332,20 @@ func (c *BackupCleaner) cleanByGFS(backupConfig *backups_config.BackupConfig) er
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete backup by GFS policy",
"backupId",
backup.ID,
"error",
err,
)
logger.Error("failed to delete backup by GFS policy", "backup_id", backup.ID, "error", err)
continue
}
c.logger.Info(
"Deleted backup by GFS policy",
"backupId", backup.ID,
"databaseId", backupConfig.DatabaseID,
)
logger.Info("deleted backup by GFS policy", "backup_id", backup.ID)
}
return nil
}
func (c *BackupCleaner) cleanExceededBackupsForDatabase(
logger *slog.Logger,
databaseID uuid.UUID,
limitperDbMB int64,
limitPerDbMB int64,
) error {
for {
backupsTotalSizeMB, err := c.backupRepository.GetTotalSizeByDatabase(databaseID)
@@ -322,7 +353,7 @@ func (c *BackupCleaner) cleanExceededBackupsForDatabase(
return err
}
if backupsTotalSizeMB <= float64(limitperDbMB) {
if backupsTotalSizeMB <= float64(limitPerDbMB) {
break
}
@@ -335,59 +366,27 @@ func (c *BackupCleaner) cleanExceededBackupsForDatabase(
}
if len(oldestBackups) == 0 {
c.logger.Warn(
"No backups to delete but still over limit",
"databaseId",
databaseID,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
)
logger.Warn(fmt.Sprintf(
"no backups to delete but still over limit: total size is %.1f MB, limit is %d MB",
backupsTotalSizeMB, limitPerDbMB,
))
break
}
backup := oldestBackups[0]
if isRecentBackup(backup) {
c.logger.Warn(
"Oldest backup is too recent to delete, stopping size cleanup",
"databaseId",
databaseID,
"backupId",
backup.ID,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
)
break
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete exceeded backup",
"backupId",
backup.ID,
"databaseId",
databaseID,
"error",
err,
)
logger.Error("failed to delete exceeded backup", "backup_id", backup.ID, "error", err)
return err
}
c.logger.Info(
"Deleted exceeded backup",
"backupId",
backup.ID,
"databaseId",
databaseID,
"backupSizeMB",
backup.BackupSizeMb,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
logger.Info(
fmt.Sprintf("deleted exceeded backup: backup size is %.1f MB, total size is %.1f MB, limit is %d MB",
backup.BackupSizeMb, backupsTotalSizeMB, limitPerDbMB),
"backup_id", backup.ID,
)
}
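
Illustrative note (not part of the diff): a recurring change in this file is passing pre-scoped *slog.Logger values built with With(...) into the cleanup helpers, so fields like task_name and database_id are attached once instead of being repeated on every log call. A small hedged sketch of that technique (the handler setup and IDs are assumptions for the example):

// Sketch only: deriving per-task and per-database loggers with slog.With.
package main

import (
	"log/slog"
	"os"
)

func main() {
	base := slog.New(slog.NewTextHandler(os.Stdout, nil))

	retentionLog := base.With("task_name", "clean_by_retention_policy")
	dbLog := retentionLog.With("database_id", "3f2a-example") // hypothetical ID

	// Every record emitted through dbLog carries both fields automatically.
	dbLog.Info("deleted old backup", "backup_id", "b1")
	dbLog.Error("failed to delete old backup", "backup_id", "b2", "error", "timeout")
}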

View File

@@ -33,7 +33,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
// backupsEveryDay returns n backups, newest-first, each 1 day apart.
backupsEveryDay := func(n int) []*backups_core.Backup {
bs := make([]*backups_core.Backup, n)
for i := 0; i < n; i++ {
for i := range n {
bs[i] = newBackup(ref.Add(-time.Duration(i) * day))
}
return bs
@@ -42,7 +42,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
// backupsEveryWeek returns n backups, newest-first, each 7 days apart.
backupsEveryWeek := func(n int) []*backups_core.Backup {
bs := make([]*backups_core.Backup, n)
for i := 0; i < n; i++ {
for i := range n {
bs[i] = newBackup(ref.Add(-time.Duration(i) * week))
}
return bs
@@ -53,7 +53,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
// backupsEveryHour returns n backups, newest-first, each 1 hour apart.
backupsEveryHour := func(n int) []*backups_core.Backup {
bs := make([]*backups_core.Backup, n)
for i := 0; i < n; i++ {
for i := range n {
bs[i] = newBackup(ref.Add(-time.Duration(i) * hour))
}
return bs
@@ -62,7 +62,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
// backupsEveryMonth returns n backups, newest-first, each ~1 month apart.
backupsEveryMonth := func(n int) []*backups_core.Backup {
bs := make([]*backups_core.Backup, n)
for i := 0; i < n; i++ {
for i := range n {
bs[i] = newBackup(ref.AddDate(0, -i, 0))
}
return bs
@@ -71,7 +71,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
// backupsEveryYear returns n backups, newest-first, each 1 year apart.
backupsEveryYear := func(n int) []*backups_core.Backup {
bs := make([]*backups_core.Backup, n)
for i := 0; i < n; i++ {
for i := range n {
bs[i] = newBackup(ref.AddDate(-i, 0, 0))
}
return bs
@@ -410,7 +410,7 @@ func Test_CleanByGFS_KeepsCorrectBackupsPerSlot(t *testing.T) {
// Create 5 backups on 5 different days; only the 3 newest days should be kept
var backupIDs []uuid.UUID
for i := 0; i < 5; i++ {
for i := range 5 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -425,7 +425,7 @@ func Test_CleanByGFS_KeepsCorrectBackupsPerSlot(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -486,7 +486,7 @@ func Test_CleanByGFS_WithWeeklyAndMonthlySlots_KeepsWiderSpread(t *testing.T) {
// Create one backup per week for 6 weeks (each on Monday of that week)
// GFS should keep: 2 daily (most recent 2 unique days) + 2 weekly + 1 monthly = up to 5 unique
var createdIDs []uuid.UUID
for i := 0; i < 6; i++ {
for i := range 6 {
weekOffset := time.Duration(5-i) * 7 * 24 * time.Hour
backup := &backups_core.Backup{
ID: uuid.New(),
@@ -502,7 +502,7 @@ func Test_CleanByGFS_WithWeeklyAndMonthlySlots_KeepsWiderSpread(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -561,7 +561,7 @@ func Test_CleanByGFS_WithHourlySlots_KeepsCorrectBackups(t *testing.T) {
// Create 5 backups spaced 1 hour apart; only the 3 newest hours should be kept
var backupIDs []uuid.UUID
for i := 0; i < 5; i++ {
for i := range 5 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -576,7 +576,7 @@ func Test_CleanByGFS_WithHourlySlots_KeepsCorrectBackups(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -677,7 +677,7 @@ func Test_CleanByGFS_SkipsRecentBackup_WhenNotInKeepSet(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -759,7 +759,7 @@ func Test_CleanByGFS_With20DailyBackups_KeepsOnlyExpectedCount(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -824,8 +824,8 @@ func Test_CleanByGFS_WithMultipleBackupsPerDay_KeepsOnlyOnePerDailySlot(t *testi
// Create 3 backups per day for 10 days = 30 total, all beyond grace period.
// Each day gets backups at base+0h, base+6h, base+12h.
for day := 0; day < 10; day++ {
for sub := 0; sub < 3; sub++ {
for day := range 10 {
for sub := range 3 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -844,7 +844,7 @@ func Test_CleanByGFS_WithMultipleBackupsPerDay_KeepsOnlyOnePerDailySlot(t *testi
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -915,7 +915,7 @@ func Test_CleanByGFS_With24HourlySlotsAnd23DailyBackups_DeletesExcessBackups(t *
now := time.Now().UTC()
for i := 0; i < 23; i++ {
for i := range 23 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -929,7 +929,7 @@ func Test_CleanByGFS_With24HourlySlotsAnd23DailyBackups_DeletesExcessBackups(t *
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -985,7 +985,7 @@ func Test_CleanByGFS_WithDisabledHourlySlotsAnd23DailyBackups_DeletesExcessBacku
now := time.Now().UTC()
for i := 0; i < 23; i++ {
for i := range 23 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -999,7 +999,7 @@ func Test_CleanByGFS_WithDisabledHourlySlotsAnd23DailyBackups_DeletesExcessBacku
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -1055,7 +1055,7 @@ func Test_CleanByGFS_WithDailySlotsAndWeeklyBackups_DeletesExcessBackups(t *test
// Create 10 weekly backups (1 per week, all >2h old past grace period).
// With 7d/4w config, correct behavior: ~8 kept (4 weekly + overlap with daily for recent ones).
// Daily slots should NOT absorb weekly backups that are older than 7 days.
for i := 0; i < 10; i++ {
for i := range 10 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -1069,7 +1069,7 @@ func Test_CleanByGFS_WithDailySlotsAndWeeklyBackups_DeletesExcessBackups(t *test
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -1138,7 +1138,7 @@ func Test_CleanByGFS_WithWeeklySlotsAndMonthlyBackups_DeletesExcessBackups(t *te
// With 52w/3m config, correct behavior: 3 kept (3 monthly slots; weekly should only
// cover recent 52 weeks but not artificially retain old monthly backups).
// Bug: all 8 kept because each monthly backup fills a unique weekly slot.
for i := 0; i < 8; i++ {
for i := range 8 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -1152,7 +1152,7 @@ func Test_CleanByGFS_WithWeeklySlotsAndMonthlyBackups_DeletesExcessBackups(t *te
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)

View File

@@ -1,14 +1,17 @@
package backuping
import (
"log/slog"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
billing_models "databasus-backend/internal/features/billing/models"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
@@ -17,6 +20,7 @@ import (
users_testing "databasus-backend/internal/features/users/testing"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/storage"
"databasus-backend/internal/util/logger"
"databasus-backend/internal/util/period"
)
@@ -51,6 +55,7 @@ func Test_CleanOldBackups_DeletesBackupsOlderThanRetentionTimePeriod(t *testing.
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -89,7 +94,7 @@ func Test_CleanOldBackups_DeletesBackupsOlderThanRetentionTimePeriod(t *testing.
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -129,6 +134,7 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -145,7 +151,7 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -154,7 +160,8 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
assert.Equal(t, oldBackup.ID, remainingBackups[0].ID)
}
func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
func Test_CleanExceededBackups_WhenUnderStorageLimit_NoBackupsDeleted(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -178,33 +185,36 @@ func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 100,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
for i := 0; i < 3; i++ {
for i := range 3 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 16.67,
BackupSizeMb: 100,
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
}
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -212,7 +222,8 @@ func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
assert.Equal(t, 3, len(remainingBackups))
}
func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T) {
func Test_CleanExceededBackups_WhenOverStorageLimit_DeletesOldestBackups(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -236,27 +247,29 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 30,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// 5 backups at 300 MB each = 1500 MB total, limit = 1 GB (1024 MB)
// Expect 2 oldest deleted, 3 remain (900 MB < 1024 MB)
now := time.Now().UTC()
var backupIDs []uuid.UUID
for i := 0; i < 5; i++ {
for i := range 5 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
BackupSizeMb: 300,
CreatedAt: now.Add(-time.Duration(4-i) * time.Hour),
}
err = backupRepository.Save(backup)
@@ -264,8 +277,11 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
backupIDs = append(backupIDs, backup.ID)
}
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -284,6 +300,7 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
}
func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -307,28 +324,29 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 50,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
now := time.Now().UTC()
// 3 completed at 500 MB each = 1500 MB, limit = 1 GB (1024 MB)
completedBackups := make([]*backups_core.Backup, 3)
for i := 0; i < 3; i++ {
for i := range 3 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 30,
BackupSizeMb: 500,
CreatedAt: now.Add(-time.Duration(3-i) * time.Hour),
}
err = backupRepository.Save(backup)
@@ -347,8 +365,11 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
err = backupRepository.Save(inProgressBackup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -365,7 +386,8 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
assert.True(t, inProgressFound, "In-progress backup should not be deleted")
}
func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
func Test_CleanExceededBackups_WithZeroStorageLimit_RemovesAllBackups(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -389,38 +411,42 @@ func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 0,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
for i := 0; i < 10; i++ {
for i := range 10 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 100,
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
CreatedAt: time.Now().UTC().Add(-time.Duration(i+2) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
}
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
// StorageGB=0 means no storage allowed — all backups should be removed
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 0, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 10, len(remainingBackups))
assert.Equal(t, 0, len(remainingBackups))
}
func Test_GetTotalSizeByDatabase_CalculatesCorrectly(t *testing.T) {
@@ -522,13 +548,14 @@ func Test_CleanByCount_KeepsNewestNBackups_DeletesOlder(t *testing.T) {
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
now := time.Now().UTC()
var backupIDs []uuid.UUID
for i := 0; i < 5; i++ {
for i := range 5 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -545,7 +572,7 @@ func Test_CleanByCount_KeepsNewestNBackups_DeletesOlder(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -594,11 +621,12 @@ func Test_CleanByCount_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
for i := 0; i < 5; i++ {
for i := range 5 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -612,7 +640,7 @@ func Test_CleanByCount_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -651,13 +679,14 @@ func Test_CleanByCount_DoesNotDeleteInProgressBackups(t *testing.T) {
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
now := time.Now().UTC()
for i := 0; i < 3; i++ {
for i := range 3 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -682,7 +711,7 @@ func Test_CleanByCount_DoesNotDeleteInProgressBackups(t *testing.T) {
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -776,6 +805,7 @@ func Test_CleanByTimePeriod_SkipsRecentBackup_EvenIfOlderThanRetention(t *testin
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -805,7 +835,7 @@ func Test_CleanByTimePeriod_SkipsRecentBackup_EvenIfOlderThanRetention(t *testin
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -847,6 +877,7 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -893,7 +924,7 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
err = cleaner.cleanByRetentionPolicy(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -914,7 +945,8 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
assert.True(t, remainingIDs[newestBackup.ID], "Newest backup should be preserved")
}
func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testing.T) {
func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverStorageLimit(t *testing.T) {
enableCloud(t)
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -937,18 +969,18 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
interval := createTestInterval()
// Total size limit is 10 MB. We have two backups of 8 MB each (16 MB total).
// Total size limit = 1 GB (1024 MB). Two backups of 600 MB each (1200 MB total).
// The oldest backup was created 30 minutes ago — within the grace period.
// The cleaner must stop and leave both backups intact.
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 10,
BackupIntervalID: interval.ID,
BackupInterval: interval,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
@@ -960,7 +992,7 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 8,
BackupSizeMb: 600,
CreatedAt: now.Add(-30 * time.Minute),
}
newerRecentBackup := &backups_core.Backup{
@@ -968,7 +1000,7 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 8,
BackupSizeMb: 600,
CreatedAt: now.Add(-10 * time.Minute),
}
@@ -977,8 +1009,11 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
err = backupRepository.Save(newerRecentBackup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -991,6 +1026,82 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
)
}
func Test_CleanExceededStorageBackups_WhenNonCloud_SkipsCleanup(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
Encryption: backups_config.BackupEncryptionEncrypted,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// 5 backups at 500 MB each = 2500 MB, would exceed 1 GB limit in cloud mode
now := time.Now().UTC()
for i := range 5 {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 500,
CreatedAt: now.Add(-time.Duration(i+2) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
}
// IsCloud is false by default — cleaner should skip entirely
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
}
cleaner := CreateTestBackupCleaner(mockBilling)
err = cleaner.cleanExceededStorageBackups(testLogger())
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 5, len(remainingBackups), "All backups must remain in non-cloud mode")
}
type mockBillingService struct {
subscription *billing_models.Subscription
err error
}
func (m *mockBillingService) GetSubscription(
logger *slog.Logger,
databaseID uuid.UUID,
) (*billing_models.Subscription, error) {
return m.subscription, m.err
}
// Mock listener for testing
type mockBackupRemoveListener struct {
onBeforeBackupRemove func(*backups_core.Backup) error
@@ -1004,6 +1115,203 @@ func (m *mockBackupRemoveListener) OnBeforeBackupRemove(backup *backups_core.Bac
return nil
}
func Test_CleanStaleUploadedBasebackups_MarksAsFailed(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
staleTime := time.Now().UTC().Add(-15 * time.Minute)
walBackupType := backups_core.PgWalBackupTypeFullBackup
staleBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
PgWalBackupType: &walBackupType,
UploadCompletedAt: &staleTime,
CreatedAt: staleTime,
}
err := backupRepository.Save(staleBackup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
assert.NoError(t, err)
updated, err := backupRepository.FindByID(staleBackup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusFailed, updated.Status)
assert.NotNil(t, updated.FailMessage)
assert.Contains(t, *updated.FailMessage, "finalization timed out")
}
func Test_CleanStaleUploadedBasebackups_SkipsRecentUploads(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
recentTime := time.Now().UTC().Add(-2 * time.Minute)
walBackupType := backups_core.PgWalBackupTypeFullBackup
recentBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
PgWalBackupType: &walBackupType,
UploadCompletedAt: &recentTime,
CreatedAt: recentTime,
}
err := backupRepository.Save(recentBackup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
assert.NoError(t, err)
updated, err := backupRepository.FindByID(recentBackup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusInProgress, updated.Status)
}
func Test_CleanStaleUploadedBasebackups_SkipsActiveStreaming(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
walBackupType := backups_core.PgWalBackupTypeFullBackup
activeBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
PgWalBackupType: &walBackupType,
CreatedAt: time.Now().UTC().Add(-30 * time.Minute),
}
err := backupRepository.Save(activeBackup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
assert.NoError(t, err)
updated, err := backupRepository.FindByID(activeBackup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusInProgress, updated.Status)
assert.Nil(t, updated.UploadCompletedAt)
}
func Test_CleanStaleUploadedBasebackups_CleansStorageFiles(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
staleTime := time.Now().UTC().Add(-15 * time.Minute)
walBackupType := backups_core.PgWalBackupTypeFullBackup
staleBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
PgWalBackupType: &walBackupType,
UploadCompletedAt: &staleTime,
BackupSizeMb: 500,
FileName: "stale-basebackup-test-file",
CreatedAt: staleTime,
}
err := backupRepository.Save(staleBackup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
assert.NoError(t, err)
updated, err := backupRepository.FindByID(staleBackup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusFailed, updated.Status)
assert.NotNil(t, updated.FailMessage)
assert.Contains(t, *updated.FailMessage, "finalization timed out")
}
func enableCloud(t *testing.T) {
t.Helper()
config.GetEnv().IsCloud = true
t.Cleanup(func() {
config.GetEnv().IsCloud = false
})
}
func testLogger() *slog.Logger {
return logger.GetLogger().With("task_name", "test")
}
func createTestInterval() *intervals.Interval {
timeOfDay := "04:00"
interval := &intervals.Interval{

View File

@@ -1,7 +1,6 @@
package backuping
import (
"sync"
"sync/atomic"
"time"
@@ -10,6 +9,7 @@ import (
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/billing"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
@@ -28,10 +28,10 @@ var backupCleaner = &BackupCleaner{
backupRepository,
storages.GetStorageService(),
backups_config.GetBackupConfigService(),
billing.GetBillingService(),
encryption.GetFieldEncryptor(),
logger.GetLogger(),
[]backups_core.BackupRemoveListener{},
sync.Once{},
atomic.Bool{},
}
@@ -41,7 +41,6 @@ var backupNodesRegistry = &BackupNodesRegistry{
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
sync.Once{},
atomic.Bool{},
}
@@ -63,7 +62,6 @@ var backuperNode = &BackuperNode{
usecases.GetCreateBackupUsecase(),
getNodeID(),
time.Time{},
sync.Once{},
atomic.Bool{},
}
@@ -73,11 +71,11 @@ var backupsScheduler = &BackupsScheduler{
taskCancelManager,
backupNodesRegistry,
databases.GetDatabaseService(),
billing.GetBillingService(),
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]BackupToNodeRelation),
backuperNode,
sync.Once{},
atomic.Bool{},
}

View File

@@ -0,0 +1,13 @@
package backuping
import (
"log/slog"
"github.com/google/uuid"
billing_models "databasus-backend/internal/features/billing/models"
)
type BillingService interface {
GetSubscription(logger *slog.Logger, databaseID uuid.UUID) (*billing_models.Subscription, error)
}
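The scheduler tests later in this diff construct a mockBillingService that satisfies this interface; its definition is not part of the diff, so the following is only a sketch, with field names inferred from how the mock is instantiated in those tests (it reuses the slog, uuid and billing_models imports shown above).
type mockBillingService struct {
	// subscription is returned as-is; err can simulate a billing lookup failure.
	subscription *billing_models.Subscription
	err          error
}
func (m *mockBillingService) GetSubscription(
	_ *slog.Logger,
	_ uuid.UUID,
) (*billing_models.Subscription, error) {
	return m.subscription, m.err
}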

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"log/slog"
"strings"
"sync"
"sync/atomic"
"time"
@@ -50,36 +49,30 @@ type BackupNodesRegistry struct {
pubsubBackups *cache_utils.PubSubManager
pubsubCompletions *cache_utils.PubSubManager
runOnce sync.Once
hasRun atomic.Bool
hasRun atomic.Bool
}
func (r *BackupNodesRegistry) Run(ctx context.Context) {
wasAlreadyRun := r.hasRun.Load()
if r.hasRun.Swap(true) {
panic(fmt.Sprintf("%T.Run() called multiple times", r))
}
r.runOnce.Do(func() {
r.hasRun.Store(true)
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
}
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
}
ticker := time.NewTicker(cleanupTickerInterval)
defer ticker.Stop()
ticker := time.NewTicker(cleanupTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes", "error", err)
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes", "error", err)
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", r))
}
}

View File

@@ -4,8 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
@@ -322,7 +320,7 @@ func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) {
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
ctx, cancel := context.WithTimeout(t.Context(), registry.timeout)
defer cancel()
invalidKey := nodeInfoKeyPrefix + uuid.New().String() + nodeInfoKeySuffix
@@ -331,7 +329,7 @@ func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) {
registry.client.B().Set().Key(invalidKey).Value("invalid json data").Build(),
)
defer func() {
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), registry.timeout)
cleanupCtx, cleanupCancel := context.WithTimeout(t.Context(), registry.timeout)
defer cleanupCancel()
registry.client.Do(cleanupCtx, registry.client.B().Del().Key(invalidKey).Build())
}()
@@ -401,7 +399,7 @@ func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) {
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
assert.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
ctx, cancel := context.WithTimeout(t.Context(), registry.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
@@ -419,7 +417,7 @@ func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) {
modifiedData, err := json.Marshal(node)
assert.NoError(t, err)
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
setCtx, setCancel := context.WithTimeout(t.Context(), registry.timeout)
defer setCancel()
setResult := registry.client.Do(
setCtx,
@@ -464,7 +462,7 @@ func Test_GetBackupNodesStats_ExcludesStaleNodesFromCache(t *testing.T) {
err = registry.IncrementBackupsInProgress(node3.ID)
assert.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
ctx, cancel := context.WithTimeout(t.Context(), registry.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
@@ -482,7 +480,7 @@ func Test_GetBackupNodesStats_ExcludesStaleNodesFromCache(t *testing.T) {
modifiedData, err := json.Marshal(node)
assert.NoError(t, err)
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
setCtx, setCancel := context.WithTimeout(t.Context(), registry.timeout)
defer setCancel()
setResult := registry.client.Do(
setCtx,
@@ -524,7 +522,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
err = registry.IncrementBackupsInProgress(node2.ID)
assert.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
ctx, cancel := context.WithTimeout(t.Context(), registry.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
@@ -542,7 +540,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
modifiedData, err := json.Marshal(node)
assert.NoError(t, err)
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
setCtx, setCancel := context.WithTimeout(t.Context(), registry.timeout)
defer setCancel()
setResult := registry.client.Do(
setCtx,
@@ -553,7 +551,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
err = registry.cleanupDeadNodes()
assert.NoError(t, err)
checkCtx, checkCancel := context.WithTimeout(context.Background(), registry.timeout)
checkCtx, checkCancel := context.WithTimeout(t.Context(), registry.timeout)
defer checkCancel()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
@@ -566,7 +564,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
node2.ID.String(),
nodeActiveBackupsSuffix,
)
counterCtx, counterCancel := context.WithTimeout(context.Background(), registry.timeout)
counterCtx, counterCancel := context.WithTimeout(t.Context(), registry.timeout)
defer counterCancel()
counterResult := registry.client.Do(
counterCtx,
@@ -575,7 +573,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
assert.Error(t, counterResult.Error())
activeInfoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node1.ID.String(), nodeInfoKeySuffix)
activeCtx, activeCancel := context.WithTimeout(context.Background(), registry.timeout)
activeCtx, activeCancel := context.WithTimeout(t.Context(), registry.timeout)
defer activeCancel()
activeResult := registry.client.Do(
activeCtx,
@@ -601,8 +599,6 @@ func createTestRegistry() *BackupNodesRegistry {
timeout: cache_utils.DefaultCacheTimeout,
pubsubBackups: cache_utils.NewPubSubManager(),
pubsubCompletions: cache_utils.NewPubSubManager(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
@@ -732,7 +728,7 @@ func Test_SubscribeNodeForBackupsAssignment_HandlesInvalidJson(t *testing.T) {
time.Sleep(100 * time.Millisecond)
ctx := context.Background()
ctx := t.Context()
err = registry.pubsubBackups.Publish(ctx, "backup:submit", "invalid json")
assert.NoError(t, err)
@@ -978,7 +974,7 @@ func Test_SubscribeForBackupsCompletions_HandlesInvalidJson(t *testing.T) {
time.Sleep(100 * time.Millisecond)
ctx := context.Background()
ctx := t.Context()
err = registry.pubsubCompletions.Publish(ctx, "backup:completion", "invalid json")
assert.NoError(t, err)
@@ -1093,7 +1089,7 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
receivedAll2 := []uuid.UUID{}
receivedAll3 := []uuid.UUID{}
for i := 0; i < 3; i++ {
for range 3 {
select {
case received := <-receivedBackups1:
receivedAll1 = append(receivedAll1, received)
@@ -1102,7 +1098,7 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
}
}
for i := 0; i < 3; i++ {
for range 3 {
select {
case received := <-receivedBackups2:
receivedAll2 = append(receivedAll2, received)
@@ -1111,7 +1107,7 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
}
}
for i := 0; i < 3; i++ {
for range 3 {
select {
case received := <-receivedBackups3:
receivedAll3 = append(receivedAll3, received)

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
@@ -29,6 +28,7 @@ type BackupsScheduler struct {
taskCancelManager *task_cancellation.TaskCancelManager
backupNodesRegistry *BackupNodesRegistry
databaseService *databases.DatabaseService
billingService BillingService
lastBackupTime time.Time
logger *slog.Logger
@@ -36,68 +36,61 @@ type BackupsScheduler struct {
backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
backuperNode *BackuperNode
runOnce sync.Once
hasRun atomic.Bool
hasRun atomic.Bool
}
func (s *BackupsScheduler) Run(ctx context.Context) {
wasAlreadyRun := s.hasRun.Load()
s.runOnce.Do(func() {
s.hasRun.Store(true)
s.lastBackupTime = time.Now().UTC()
if config.GetEnv().IsManyNodesMode {
// wait for other nodes to start
time.Sleep(schedulerStartupDelay)
}
if err := s.failBackupsInProgress(); err != nil {
s.logger.Error("Failed to fail backups in progress", "error", err)
panic(err)
}
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
if err != nil {
s.logger.Error("Failed to subscribe to backup completions", "error", err)
panic(err)
}
defer func() {
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
}
}()
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(schedulerTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.checkDeadNodesAndFailBackups(); err != nil {
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
}
if err := s.runPendingBackups(); err != nil {
s.logger.Error("Failed to run pending backups", "error", err)
}
s.lastBackupTime = time.Now().UTC()
}
}
})
if wasAlreadyRun {
if s.hasRun.Swap(true) {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
s.lastBackupTime = time.Now().UTC()
if config.GetEnv().IsManyNodesMode {
// wait for other nodes to start
time.Sleep(schedulerStartupDelay)
}
if err := s.failBackupsInProgress(); err != nil {
s.logger.Error("Failed to fail backups in progress", "error", err)
panic(err)
}
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
if err != nil {
s.logger.Error("Failed to subscribe to backup completions", "error", err)
panic(err)
}
defer func() {
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
}
}()
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(schedulerTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.checkDeadNodesAndFailBackups(); err != nil {
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
}
if err := s.runPendingBackups(); err != nil {
s.logger.Error("Failed to run pending backups", "error", err)
}
s.lastBackupTime = time.Now().UTC()
}
}
}
func (s *BackupsScheduler) IsSchedulerRunning() bool {
@@ -127,6 +120,34 @@ func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotif
return
}
if config.GetEnv().IsCloud {
subscription, subErr := s.billingService.GetSubscription(s.logger, database.ID)
if subErr != nil || !subscription.CanCreateNewBackups() {
failMessage := "subscription has expired, please renew"
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: *backupConfig.StorageID,
Status: backups_core.BackupStatusFailed,
FailMessage: &failMessage,
IsSkipRetry: true,
CreatedAt: time.Now().UTC(),
}
backup.GenerateFilename(database.Name)
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error(
"failed to save failed backup for expired subscription",
"database_id", database.ID,
"error", err,
)
}
return
}
}
// Check for existing in-progress backups
inProgressBackups, err := s.backupRepository.FindByDatabaseIdAndStatus(
database.ID,
@@ -342,6 +363,31 @@ func (s *BackupsScheduler) runPendingBackups() error {
continue
}
if database.IsAgentManagedBackup() {
continue
}
if config.GetEnv().IsCloud {
subscription, subErr := s.billingService.GetSubscription(s.logger, backupConfig.DatabaseID)
if subErr != nil {
s.logger.Warn(
"failed to get subscription, skipping backup",
"database_id", backupConfig.DatabaseID,
"error", subErr,
)
continue
}
if !subscription.CanCreateNewBackups() {
s.logger.Debug(
"subscription is not active, skipping scheduled backup",
"database_id", backupConfig.DatabaseID,
"subscription_status", subscription.Status,
)
continue
}
}
s.StartBackup(database, remainedBackupTryCount == 1)
continue
}

View File

@@ -1,6 +1,7 @@
package backuping
import (
"context"
"testing"
"time"
@@ -9,6 +10,7 @@ import (
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
billing_models "databasus-backend/internal/features/billing/models"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
@@ -20,6 +22,128 @@ import (
"databasus-backend/internal/util/period"
)
func Test_RunPendingBackups_ByDatabaseType_OnlySchedulesNonAgentManagedBackups(t *testing.T) {
type testCase struct {
name string
createDatabase func(workspaceID uuid.UUID, storage *storages.Storage, notifier *notifiers.Notifier) *databases.Database
isBackupExpected bool
needsBackuperNode bool
}
testCases := []testCase{
{
name: "PostgreSQL PG_DUMP - backup runs",
createDatabase: func(workspaceID uuid.UUID, storage *storages.Storage, notifier *notifiers.Notifier) *databases.Database {
return databases.CreateTestDatabase(workspaceID, storage, notifier)
},
isBackupExpected: true,
needsBackuperNode: true,
},
{
name: "PostgreSQL WAL_V1 - backup skipped (agent-managed)",
createDatabase: func(workspaceID uuid.UUID, _ *storages.Storage, notifier *notifiers.Notifier) *databases.Database {
return databases.CreateTestPostgresWalDatabase(workspaceID, notifier)
},
isBackupExpected: false,
needsBackuperNode: false,
},
{
name: "MariaDB - backup runs",
createDatabase: func(workspaceID uuid.UUID, _ *storages.Storage, notifier *notifiers.Notifier) *databases.Database {
return databases.CreateTestMariadbDatabase(workspaceID, notifier)
},
isBackupExpected: true,
needsBackuperNode: true,
},
{
name: "MongoDB - backup runs",
createDatabase: func(workspaceID uuid.UUID, _ *storages.Storage, notifier *notifiers.Notifier) *databases.Database {
return databases.CreateTestMongodbDatabase(workspaceID, notifier)
},
isBackupExpected: true,
needsBackuperNode: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cache_utils.ClearAllCache()
var backuperNode *BackuperNode
var cancel context.CancelFunc
if tc.needsBackuperNode {
backuperNode = CreateTestBackuperNode()
cancel = StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
}
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := tc.createDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// add old backup (24h ago)
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
})
GetBackupsScheduler().runPendingBackups()
if tc.isBackupExpected {
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 2)
} else {
time.Sleep(100 * time.Millisecond)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
}
time.Sleep(200 * time.Millisecond)
})
}
}
func Test_RunPendingBackups_WhenLastBackupWasYesterday_CreatesNewBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
@@ -845,7 +969,7 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
scheduler := CreateTestScheduler()
scheduler := CreateTestScheduler(nil)
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
@@ -942,7 +1066,7 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
scheduler := CreateTestScheduler()
scheduler := CreateTestScheduler(nil)
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
@@ -1209,7 +1333,7 @@ func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCa
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// Create scheduler
scheduler := CreateTestScheduler()
scheduler := CreateTestScheduler(nil)
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
@@ -1335,3 +1459,313 @@ func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCa
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_WhenCloudAndSubscriptionExpired_CreatesFailedBackup(t *testing.T) {
cache_utils.ClearAllCache()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{
Status: billing_models.StatusExpired,
},
}
scheduler := CreateTestScheduler(mockBilling)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
enableCloud(t)
scheduler.StartBackup(database, false)
time.Sleep(100 * time.Millisecond)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
newestBackup := backups[0]
assert.Equal(t, backups_core.BackupStatusFailed, newestBackup.Status)
assert.NotNil(t, newestBackup.FailMessage)
assert.Equal(t, "subscription has expired, please renew", *newestBackup.FailMessage)
assert.True(t, newestBackup.IsSkipRetry)
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_WhenCloudAndSubscriptionActive_ProceedsNormally(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{
Status: billing_models.StatusActive,
StorageGB: 10,
},
}
scheduler := CreateTestScheduler(mockBilling)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
enableCloud(t)
scheduler.StartBackup(database, false)
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
newestBackup := backups[0]
assert.Equal(t, backups_core.BackupStatusCompleted, newestBackup.Status)
time.Sleep(200 * time.Millisecond)
}
func Test_RunPendingBackups_WhenCloudAndSubscriptionExpired_SilentlySkips(t *testing.T) {
cache_utils.ClearAllCache()
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{
Status: billing_models.StatusExpired,
},
}
scheduler := CreateTestScheduler(mockBilling)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
})
enableCloud(t)
scheduler.runPendingBackups()
time.Sleep(100 * time.Millisecond)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1, "No new backup should be created, scheduler silently skips expired subscriptions")
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_WhenNotCloudAndSubscriptionExpired_ProceedsNormally(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{
Status: billing_models.StatusExpired,
},
}
scheduler := CreateTestScheduler(mockBilling)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
scheduler.StartBackup(database, false)
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusCompleted, backups[0].Status,
"Billing check should not apply in non-cloud mode")
time.Sleep(200 * time.Millisecond)
}
func Test_RunPendingBackups_WhenNotCloudAndSubscriptionExpired_ProceedsNormally(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
mockBilling := &mockBillingService{
subscription: &billing_models.Subscription{
Status: billing_models.StatusExpired,
},
}
scheduler := CreateTestScheduler(mockBilling)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
})
scheduler.runPendingBackups()
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 2, "Billing check should not apply in non-cloud mode, new backup should be created")
time.Sleep(200 * time.Millisecond)
}

View File

@@ -3,7 +3,6 @@ package backuping
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
@@ -35,58 +34,70 @@ func CreateTestRouter() *gin.Engine {
return router
}
func CreateTestBackupCleaner(billingService BillingService) *BackupCleaner {
return &BackupCleaner{
backupRepository,
storages.GetStorageService(),
backups_config.GetBackupConfigService(),
billingService,
encryption.GetFieldEncryptor(),
logger.GetLogger(),
[]backups_core.BackupRemoveListener{},
atomic.Bool{},
}
}
func CreateTestBackuperNode() *BackuperNode {
return &BackuperNode{
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
nodeID: uuid.New(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
taskCancelManager,
backupNodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
uuid.New(),
time.Time{},
atomic.Bool{},
}
}
func CreateTestBackuperNodeWithUseCase(useCase backups_core.CreateBackupUsecase) *BackuperNode {
return &BackuperNode{
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: useCase,
nodeID: uuid.New(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
taskCancelManager,
backupNodesRegistry,
logger.GetLogger(),
useCase,
uuid.New(),
time.Time{},
atomic.Bool{},
}
}
func CreateTestScheduler() *BackupsScheduler {
func CreateTestScheduler(billingService BillingService) *BackupsScheduler {
return &BackupsScheduler{
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
taskCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
lastBackupTime: time.Now().UTC(),
logger: logger.GetLogger(),
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
backuperNode: CreateTestBackuperNode(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
backupRepository,
backups_config.GetBackupConfigService(),
taskCancelManager,
backupNodesRegistry,
databases.GetDatabaseService(),
billingService,
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]BackupToNodeRelation),
CreateTestBackuperNode(),
atomic.Bool{},
}
}

View File

@@ -40,12 +40,15 @@ func (c *BackupController) RegisterPublicRoutes(router *gin.RouterGroup) {
// GetBackups
// @Summary Get backups for a database
// @Description Get paginated backups for the specified database
// @Description Get paginated backups for the specified database with optional filters
// @Tags backups
// @Produce json
// @Param database_id query string true "Database ID"
// @Param limit query int false "Number of items per page" default(10)
// @Param offset query int false "Offset for pagination" default(0)
// @Param status query []string false "Filter by backup status (can be repeated)" Enums(IN_PROGRESS, COMPLETED, FAILED, CANCELED)
// @Param beforeDate query string false "Filter backups created before this date (RFC3339)" format(date-time)
// @Param pgWalBackupType query string false "Filter by WAL backup type" Enums(PG_FULL_BACKUP, PG_WAL_SEGMENT)
// @Success 200 {object} backups_dto.GetBackupsResponse
// @Failure 400
// @Failure 401
@@ -70,7 +73,9 @@ func (c *BackupController) GetBackups(ctx *gin.Context) {
return
}
response, err := c.backupService.GetBackups(user, databaseID, request.Limit, request.Offset)
filters := c.buildBackupFilters(&request)
response, err := c.backupService.GetBackups(user, databaseID, request.Limit, request.Offset, filters)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -359,3 +364,35 @@ func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uu
}
}
}
func (c *BackupController) buildBackupFilters(
request *backups_dto.GetBackupsRequest,
) *backups_core.BackupFilters {
isHasFilters := len(request.Statuses) > 0 ||
request.BeforeDate != nil ||
request.PgWalBackupType != nil
if !isHasFilters {
return nil
}
filters := &backups_core.BackupFilters{}
if len(request.Statuses) > 0 {
statuses := make([]backups_core.BackupStatus, 0, len(request.Statuses))
for _, statusStr := range request.Statuses {
statuses = append(statuses, backups_core.BackupStatus(statusStr))
}
filters.Statuses = statuses
}
filters.BeforeDate = request.BeforeDate
if request.PgWalBackupType != nil {
walType := backups_core.PgWalBackupType(*request.PgWalBackupType)
filters.PgWalBackupType = &walType
}
return filters
}
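Clients express these filters as query parameters on GET /api/v1/backups: status may be repeated, beforeDate is RFC3339, and pgWalBackupType takes one of the enum values from the Swagger annotation above. A minimal sketch of building such a URL in Go (the base URL is a placeholder; assumes the standard net/url and time imports plus github.com/google/uuid):
// Hypothetical helper; only the parameter names (database_id, status,
// beforeDate, pgWalBackupType) come from the handler above.
func buildBackupsURL(baseURL string, databaseID uuid.UUID, before time.Time) string {
	params := url.Values{}
	params.Set("database_id", databaseID.String())
	params.Add("status", "COMPLETED")
	params.Add("status", "FAILED")
	params.Set("beforeDate", before.UTC().Format(time.RFC3339))
	params.Set("pgWalBackupType", "PG_FULL_BACKUP")
	return baseURL + "/api/v1/backups?" + params.Encode()
}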

View File

@@ -140,6 +140,225 @@ func Test_GetBackups_PermissionsEnforced(t *testing.T) {
}
}
func Test_GetBackups_WithStatusFilter_ReturnsFilteredBackups(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
now := time.Now().UTC()
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
Status: backups_core.BackupStatusCompleted,
CreatedAt: now.Add(-3 * time.Hour),
})
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
Status: backups_core.BackupStatusFailed,
CreatedAt: now.Add(-2 * time.Hour),
})
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
Status: backups_core.BackupStatusCanceled,
CreatedAt: now.Add(-1 * time.Hour),
})
// Single status filter
var singleResponse backups_dto.GetBackupsResponse
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups?database_id=%s&status=COMPLETED", database.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&singleResponse,
)
assert.Equal(t, int64(1), singleResponse.Total)
assert.Len(t, singleResponse.Backups, 1)
assert.Equal(t, backups_core.BackupStatusCompleted, singleResponse.Backups[0].Status)
// Multiple status filter
var multiResponse backups_dto.GetBackupsResponse
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf(
"/api/v1/backups?database_id=%s&status=COMPLETED&status=FAILED",
database.ID.String(),
),
"Bearer "+owner.Token,
http.StatusOK,
&multiResponse,
)
assert.Equal(t, int64(2), multiResponse.Total)
assert.Len(t, multiResponse.Backups, 2)
for _, backup := range multiResponse.Backups {
assert.True(
t,
backup.Status == backups_core.BackupStatusCompleted ||
backup.Status == backups_core.BackupStatusFailed,
"expected COMPLETED or FAILED, got %s", backup.Status,
)
}
}
func Test_GetBackups_WithBeforeDateFilter_ReturnsFilteredBackups(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
now := time.Now().UTC()
cutoff := now.Add(-1 * time.Hour)
olderBackup := CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
Status: backups_core.BackupStatusCompleted,
CreatedAt: now.Add(-3 * time.Hour),
})
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
Status: backups_core.BackupStatusCompleted,
CreatedAt: now,
})
var response backups_dto.GetBackupsResponse
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf(
"/api/v1/backups?database_id=%s&beforeDate=%s",
database.ID.String(),
cutoff.Format(time.RFC3339),
),
"Bearer "+owner.Token,
http.StatusOK,
&response,
)
assert.Equal(t, int64(1), response.Total)
assert.Len(t, response.Backups, 1)
assert.Equal(t, olderBackup.ID, response.Backups[0].ID)
}
func Test_GetBackups_WithPgWalBackupTypeFilter_ReturnsFilteredBackups(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
now := time.Now().UTC()
fullBackupType := backups_core.PgWalBackupTypeFullBackup
walSegmentType := backups_core.PgWalBackupTypeWalSegment
fullBackup := CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
Status: backups_core.BackupStatusCompleted,
CreatedAt: now.Add(-2 * time.Hour),
PgWalBackupType: &fullBackupType,
})
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
Status: backups_core.BackupStatusCompleted,
CreatedAt: now.Add(-1 * time.Hour),
PgWalBackupType: &walSegmentType,
})
var response backups_dto.GetBackupsResponse
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf(
"/api/v1/backups?database_id=%s&pgWalBackupType=PG_FULL_BACKUP",
database.ID.String(),
),
"Bearer "+owner.Token,
http.StatusOK,
&response,
)
assert.Equal(t, int64(1), response.Total)
assert.Len(t, response.Backups, 1)
assert.Equal(t, fullBackup.ID, response.Backups[0].ID)
}
func Test_GetBackups_WithCombinedFilters_ReturnsFilteredBackups(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
now := time.Now().UTC()
cutoff := now.Add(-1 * time.Hour)
// Old completed — should match
oldCompleted := CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
Status: backups_core.BackupStatusCompleted,
CreatedAt: now.Add(-3 * time.Hour),
})
// Old failed — should NOT match (wrong status)
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
Status: backups_core.BackupStatusFailed,
CreatedAt: now.Add(-2 * time.Hour),
})
// New completed — should NOT match (too recent)
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
Status: backups_core.BackupStatusCompleted,
CreatedAt: now,
})
var response backups_dto.GetBackupsResponse
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf(
"/api/v1/backups?database_id=%s&status=COMPLETED&beforeDate=%s",
database.ID.String(),
cutoff.Format(time.RFC3339),
),
"Bearer "+owner.Token,
http.StatusOK,
&response,
)
assert.Equal(t, int64(1), response.Total)
assert.Len(t, response.Backups, 1)
assert.Equal(t, oldCompleted.ID, response.Backups[0].ID)
}
func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
@@ -376,7 +595,7 @@ func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
ownerUser, err := userService.GetUserFromToken(owner.Token)
assert.NoError(t, err)
response, err := backups_services.GetBackupService().GetBackups(ownerUser, database.ID, 10, 0)
response, err := backups_services.GetBackupService().GetBackups(ownerUser, database.ID, 10, 0, nil)
assert.NoError(t, err)
assert.Equal(t, 0, len(response.Backups))
}
@@ -1263,7 +1482,7 @@ func Test_MakeBackup_VerifyBackupAndMetadataFilesExistInStorage(t *testing.T) {
backuperCancel := backuping.StartBackuperNodeForTest(t, backuperNode)
defer backuping.StopBackuperNodeForTest(t, backuperCancel, backuperNode)
scheduler := backuping.CreateTestScheduler()
scheduler := backuping.CreateTestScheduler(nil)
schedulerCancel := backuping.StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
@@ -1838,7 +2057,7 @@ func Test_DeleteBackup_RemovesBackupAndMetadataFilesFromDisk(t *testing.T) {
backuperCancel := backuping.StartBackuperNodeForTest(t, backuperNode)
defer backuping.StopBackuperNodeForTest(t, backuperCancel, backuperNode)
scheduler := backuping.CreateTestScheduler()
scheduler := backuping.CreateTestScheduler(nil)
schedulerCancel := backuping.StartSchedulerForTest(t, scheduler)
defer schedulerCancel()

View File

@@ -3,12 +3,10 @@ package backups_controllers
import (
"io"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_dto "databasus-backend/internal/features/backups/backups/dto"
backups_services "databasus-backend/internal/features/backups/backups/services"
"databasus-backend/internal/features/databases"
@@ -25,8 +23,11 @@ func (c *PostgreWalBackupController) RegisterRoutes(router *gin.RouterGroup) {
walRoutes := router.Group("/backups/postgres/wal")
walRoutes.GET("/next-full-backup-time", c.GetNextFullBackupTime)
walRoutes.GET("/is-wal-chain-valid-since-last-full-backup", c.IsWalChainValidSinceLastBackup)
walRoutes.POST("/error", c.ReportError)
walRoutes.POST("/upload", c.Upload)
walRoutes.POST("/upload/wal", c.UploadWalSegment)
walRoutes.POST("/upload/full-start", c.StartFullBackupUpload)
walRoutes.POST("/upload/full-complete", c.CompleteFullBackupUpload)
walRoutes.GET("/restore/plan", c.GetRestorePlan)
walRoutes.GET("/restore/download", c.DownloadBackupFile)
}
@@ -90,91 +91,66 @@ func (c *PostgreWalBackupController) ReportError(ctx *gin.Context) {
ctx.Status(http.StatusOK)
}
// Upload
// @Summary Stream upload a basebackup or WAL segment
// @Description Accepts a zstd-compressed binary stream and stores it in the database's configured storage.
// The server generates the storage filename; agents do not control the destination path.
// For WAL segment uploads the server validates the WAL chain and returns 409 if a gap is detected
// or 400 if no full backup exists yet (agent should trigger a full basebackup in both cases).
// IsWalChainValidSinceLastBackup
// @Summary Check WAL chain validity since last full backup
// @Description Checks whether the WAL chain is continuous since the last completed full backup.
// Returns isValid=true if the chain is intact, or isValid=false with error details if not.
// @Tags backups-wal
// @Accept application/octet-stream
// @Produce json
// @Security AgentToken
// @Param X-Upload-Type header string true "Upload type" Enums(basebackup, wal)
// @Param X-Wal-Segment-Name header string false "24-hex WAL segment identifier (required for wal uploads, e.g. 0000000100000001000000AB)"
// @Param X-Wal-Segment-Size header int false "WAL segment size in bytes reported by the PostgreSQL instance (default: 16777216)"
// @Param fullBackupWalStartSegment query string false "First WAL segment needed to make the basebackup consistent (required for basebackup uploads)"
// @Param fullBackupWalStopSegment query string false "Last WAL segment included in the basebackup (required for basebackup uploads)"
// @Success 204
// @Failure 400 {object} backups_dto.UploadGapResponse "No full backup exists (error: no_full_backup)"
// @Success 200 {object} backups_dto.IsWalChainValidResponse
// @Failure 401 {object} map[string]string
// @Failure 409 {object} backups_dto.UploadGapResponse "WAL chain gap detected (error: gap_detected)"
// @Failure 500 {object} map[string]string
// @Router /backups/postgres/wal/upload [post]
func (c *PostgreWalBackupController) Upload(ctx *gin.Context) {
// @Router /backups/postgres/wal/is-wal-chain-valid-since-last-full-backup [get]
func (c *PostgreWalBackupController) IsWalChainValidSinceLastBackup(ctx *gin.Context) {
database, err := c.getDatabase(ctx)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
return
}
uploadType := backups_core.PgWalUploadType(ctx.GetHeader("X-Upload-Type"))
if uploadType != backups_core.PgWalUploadTypeBasebackup &&
uploadType != backups_core.PgWalUploadTypeWal {
response, err := c.walService.IsWalChainValid(database)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, response)
}
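One way an agent might consume this check is to decide, before its next scheduled basebackup, whether the WAL chain is still continuous. The sketch below is illustrative only: the route comes from the registration above, the isValid field name from the description, while the base URL, token, decision logic, and error handling are assumptions (assumes the standard net/http and encoding/json imports).
// Returns true when the chain is broken and a fresh full basebackup is needed.
func shouldTriggerFullBackup(client *http.Client, baseURL, agentToken string) (bool, error) {
	req, err := http.NewRequest(
		http.MethodGet,
		baseURL+"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
		nil,
	)
	if err != nil {
		return false, err
	}
	req.Header.Set("Authorization", agentToken)
	resp, err := client.Do(req)
	if err != nil {
		return false, err
	}
	defer resp.Body.Close()
	var result struct {
		IsValid bool `json:"isValid"` // field name assumed from the Swagger description
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return false, err
	}
	return !result.IsValid, nil
}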
// UploadWalSegment
// @Summary Stream upload a WAL segment
// @Description Accepts a zstd-compressed WAL segment binary stream and stores it in the database's configured storage.
// WAL segments are accepted unconditionally.
// @Tags backups-wal
// @Accept application/octet-stream
// @Security AgentToken
// @Param X-Wal-Segment-Name header string true "24-hex WAL segment identifier (e.g. 0000000100000001000000AB)"
// @Success 204
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /backups/postgres/wal/upload/wal [post]
func (c *PostgreWalBackupController) UploadWalSegment(ctx *gin.Context) {
database, err := c.getDatabase(ctx)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
return
}
walSegmentName := ctx.GetHeader("X-Wal-Segment-Name")
if walSegmentName == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "X-Upload-Type must be 'basebackup' or 'wal'"},
gin.H{"error": "X-Wal-Segment-Name is required for wal uploads"},
)
return
}
walSegmentName := ""
if uploadType == backups_core.PgWalUploadTypeWal {
walSegmentName = ctx.GetHeader("X-Wal-Segment-Name")
if walSegmentName == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "X-Wal-Segment-Name is required for wal uploads"},
)
return
}
}
if uploadType == backups_core.PgWalUploadTypeBasebackup {
if ctx.Query("fullBackupWalStartSegment") == "" ||
ctx.Query("fullBackupWalStopSegment") == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{
"error": "fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads",
},
)
return
}
}
walSegmentSizeBytes := int64(0)
if raw := ctx.GetHeader("X-Wal-Segment-Size"); raw != "" {
parsed, parseErr := strconv.ParseInt(raw, 10, 64)
if parseErr != nil || parsed <= 0 {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "X-Wal-Segment-Size must be a positive integer"},
)
return
}
walSegmentSizeBytes = parsed
}
gapResp, uploadErr := c.walService.UploadWal(
uploadErr := c.walService.UploadWalSegment(
ctx.Request.Context(),
database,
uploadType,
walSegmentName,
ctx.Query("fullBackupWalStartSegment"),
ctx.Query("fullBackupWalStopSegment"),
walSegmentSizeBytes,
ctx.Request.Body,
)
@@ -183,17 +159,89 @@ func (c *PostgreWalBackupController) Upload(ctx *gin.Context) {
return
}
if gapResp != nil {
if gapResp.Error == "no_full_backup" {
ctx.JSON(http.StatusBadRequest, gapResp)
return
}
ctx.Status(http.StatusNoContent)
}
ctx.JSON(http.StatusConflict, gapResp)
// StartFullBackupUpload
// @Summary Stream upload a full basebackup (Phase 1)
// @Description Accepts a zstd-compressed basebackup binary stream and stores it in the database's configured storage.
// Returns a backupId that must be completed via /upload/full-complete with WAL segment names.
// @Tags backups-wal
// @Accept application/octet-stream
// @Produce json
// @Security AgentToken
// @Success 200 {object} backups_dto.UploadBasebackupResponse
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /backups/postgres/wal/upload/full-start [post]
func (c *PostgreWalBackupController) StartFullBackupUpload(ctx *gin.Context) {
database, err := c.getDatabase(ctx)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
return
}
ctx.Status(http.StatusNoContent)
backupID, uploadErr := c.walService.UploadBasebackup(
ctx.Request.Context(),
database,
ctx.Request.Body,
)
if uploadErr != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": uploadErr.Error()})
return
}
ctx.JSON(http.StatusOK, backups_dto.UploadBasebackupResponse{
BackupID: backupID,
})
}
// CompleteFullBackupUpload
// @Summary Complete a previously uploaded basebackup (Phase 2)
// @Description Sets WAL segment names and marks the basebackup as completed, or marks it as failed if an error is provided.
// @Tags backups-wal
// @Accept json
// @Security AgentToken
// @Param request body backups_dto.FinalizeBasebackupRequest true "Completion details"
// @Success 200
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /backups/postgres/wal/upload/full-complete [post]
func (c *PostgreWalBackupController) CompleteFullBackupUpload(ctx *gin.Context) {
database, err := c.getDatabase(ctx)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
return
}
var request backups_dto.FinalizeBasebackupRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if request.Error == nil && (request.StartSegment == "" || request.StopSegment == "") {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "startSegment and stopSegment are required when no error is provided"},
)
return
}
if err := c.walService.FinalizeBasebackup(
database,
request.BackupID,
request.StartSegment,
request.StopSegment,
request.Error,
); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.Status(http.StatusOK)
}
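Taken together, the two endpoints give agents a two-phase flow: stream the basebackup to /upload/full-start, keep the returned backupId, then report either the WAL boundary segments or an error to /upload/full-complete. A minimal agent-side sketch under those assumptions follows; the routes, headers and JSON field names are taken from the handlers and descriptions above, while the base URL, token, segment values, and error handling are placeholders (assumes the standard bytes, encoding/json, io and net/http imports plus github.com/google/uuid).
func uploadFullBasebackup(baseURL, agentToken string, stream io.Reader) error {
	// Phase 1: stream the zstd-compressed basebackup and obtain a backupId.
	startReq, err := http.NewRequest(
		http.MethodPost,
		baseURL+"/api/v1/backups/postgres/wal/upload/full-start",
		stream,
	)
	if err != nil {
		return err
	}
	startReq.Header.Set("Authorization", agentToken)
	startReq.Header.Set("Content-Type", "application/octet-stream")
	startResp, err := http.DefaultClient.Do(startReq)
	if err != nil {
		return err
	}
	defer startResp.Body.Close()
	var started struct {
		BackupID uuid.UUID `json:"backupId"` // field name assumed from the Swagger description
	}
	if err := json.NewDecoder(startResp.Body).Decode(&started); err != nil {
		return err
	}
	// Phase 2: finalize with the WAL boundary segments (or an error message instead).
	payload, err := json.Marshal(map[string]any{
		"backupId":     started.BackupID,
		"startSegment": "000000010000000100000001",
		"stopSegment":  "000000010000000100000010",
	})
	if err != nil {
		return err
	}
	completeReq, err := http.NewRequest(
		http.MethodPost,
		baseURL+"/api/v1/backups/postgres/wal/upload/full-complete",
		bytes.NewReader(payload),
	)
	if err != nil {
		return err
	}
	completeReq.Header.Set("Authorization", agentToken)
	completeReq.Header.Set("Content-Type", "application/json")
	completeResp, err := http.DefaultClient.Do(completeReq)
	if err != nil {
		return err
	}
	return completeResp.Body.Close()
}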
// GetRestorePlan

View File

@@ -38,7 +38,7 @@ func Test_WalUpload_InProgressStatusSetBeforeStream(t *testing.T) {
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
pr, pw := io.Pipe()
req := newWalUploadRequest(pr, agentToken, "wal", "000000010000000100000011", "", "")
req := newWalSegmentUploadRequest(pr, agentToken, "000000010000000100000011")
w := httptest.NewRecorder()
done := make(chan struct{})
@@ -67,7 +67,7 @@ func Test_WalUpload_CompletedStatusAfterSuccessfulStream(t *testing.T) {
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
body := bytes.NewReader([]byte("wal segment content"))
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000011", "", "")
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000011")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -99,7 +99,7 @@ func Test_WalUpload_FailedStatusWithErrorOnStreamError(t *testing.T) {
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
pr, pw := io.Pipe()
req := newWalUploadRequest(pr, agentToken, "wal", "000000010000000100000011", "", "")
req := newWalSegmentUploadRequest(pr, agentToken, "000000010000000100000011")
w := httptest.NewRecorder()
done := make(chan struct{})
@@ -129,59 +129,171 @@ func Test_WalUpload_FailedStatusWithErrorOnStreamError(t *testing.T) {
assert.NotNil(t, walBackup.FailMessage)
}
func Test_WalUpload_Basebackup_MissingWalSegments_Returns400(t *testing.T) {
func Test_WalUpload_Basebackup_StreamingUpload_Returns200WithBackupId(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
body := bytes.NewReader([]byte("basebackup content"))
req := newWalUploadRequest(body, agentToken, backups_core.PgWalUploadTypeBasebackup, "", "", "")
req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/full-start", body)
req.Header.Set("Authorization", agentToken)
req.Header.Set("Content-Type", "application/octet-stream")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
var response backups_dto.UploadBasebackupResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
assert.NotEqual(t, uuid.Nil, response.BackupID)
backup, err := backups_core.GetBackupRepository().FindByID(response.BackupID)
require.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusInProgress, backup.Status)
assert.NotNil(t, backup.UploadCompletedAt)
}
func Test_FinalizeBasebackup_ValidSegments_MarksCompleted(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
backupID := uploadBasebackupPhase1(t, router, agentToken)
completeFullBackupUpload(t, router, agentToken, backupID,
"000000010000000100000001", "000000010000000100000010", nil)
backup, err := backups_core.GetBackupRepository().FindByID(backupID)
require.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
require.NotNil(t, backup.PgFullBackupWalStartSegmentName)
assert.Equal(t, "000000010000000100000001", *backup.PgFullBackupWalStartSegmentName)
require.NotNil(t, backup.PgFullBackupWalStopSegmentName)
assert.Equal(t, "000000010000000100000010", *backup.PgFullBackupWalStopSegmentName)
}
func Test_FinalizeBasebackup_WithError_MarksFailed(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
backupID := uploadBasebackupPhase1(t, router, agentToken)
errMsg := "pg_basebackup stderr parse failed"
completeFullBackupUpload(t, router, agentToken, backupID, "", "", &errMsg)
backup, err := backups_core.GetBackupRepository().FindByID(backupID)
require.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusFailed, backup.Status)
require.NotNil(t, backup.FailMessage)
assert.Equal(t, errMsg, *backup.FailMessage)
}
func Test_FinalizeBasebackup_InvalidBackupId_Returns400(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
nonExistentID := uuid.New()
body, _ := json.Marshal(backups_dto.FinalizeBasebackupRequest{
BackupID: nonExistentID,
StartSegment: "000000010000000100000001",
StopSegment: "000000010000000100000010",
})
req, _ := http.NewRequest(
http.MethodPost,
"/api/v1/backups/postgres/wal/upload/full-complete",
bytes.NewReader(body),
)
req.Header.Set("Authorization", agentToken)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
}
func Test_WalUpload_WalSegment_NoFullBackup_Returns400(t *testing.T) {
func Test_FinalizeBasebackup_AlreadyCompleted_Returns400(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
backupID := uploadBasebackupPhase1(t, router, agentToken)
completeFullBackupUpload(t, router, agentToken, backupID,
"000000010000000100000001", "000000010000000100000010", nil)
// Second finalize should fail.
body, _ := json.Marshal(backups_dto.FinalizeBasebackupRequest{
BackupID: backupID,
StartSegment: "000000010000000100000001",
StopSegment: "000000010000000100000010",
})
req, _ := http.NewRequest(
http.MethodPost,
"/api/v1/backups/postgres/wal/upload/full-complete",
bytes.NewReader(body),
)
req.Header.Set("Authorization", agentToken)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
}
func Test_FinalizeBasebackup_InvalidToken_Returns401(t *testing.T) {
router, db, storage, _, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
body, _ := json.Marshal(backups_dto.FinalizeBasebackupRequest{
BackupID: uuid.New(),
StartSegment: "000000010000000100000001",
StopSegment: "000000010000000100000010",
})
req, _ := http.NewRequest(
http.MethodPost,
"/api/v1/backups/postgres/wal/upload/full-complete",
bytes.NewReader(body),
)
req.Header.Set("Authorization", "invalid-token")
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}
func Test_WalUpload_WalSegment_WithoutFullBackup_Returns204(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
// No full backup inserted — chain anchor is missing.
body := bytes.NewReader([]byte("wal content"))
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000001", "", "")
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000001")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
var resp backups_dto.UploadGapResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "no_full_backup", resp.Error)
assert.Equal(t, http.StatusNoContent, w.Code)
}
func Test_WalUpload_WalSegment_GapDetected_Returns409WithExpectedAndReceived(t *testing.T) {
func Test_WalUpload_WalSegment_WithGap_Returns204(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
// Full backup stops at ...0010; upload one WAL segment at ...0011.
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
// Send ...0013 — should be rejected because ...0012 is missing.
// Skip ...0012, upload ...0013 — should succeed (no chain validation on upload).
body := bytes.NewReader([]byte("wal content"))
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000013", "", "")
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000013")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusConflict, w.Code)
var resp backups_dto.UploadGapResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "gap_detected", resp.Error)
assert.Equal(t, "000000010000000100000012", resp.ExpectedSegmentName)
assert.Equal(t, "000000010000000100000013", resp.ReceivedSegmentName)
assert.Equal(t, http.StatusNoContent, w.Code)
}
func Test_WalUpload_WalSegment_DuplicateSegment_Returns200Idempotent(t *testing.T) {
@@ -192,14 +304,14 @@ func Test_WalUpload_WalSegment_DuplicateSegment_Returns200Idempotent(t *testing.
// Upload ...0011 once.
body1 := bytes.NewReader([]byte("wal content"))
req1 := newWalUploadRequest(body1, agentToken, "wal", "000000010000000100000011", "", "")
req1 := newWalSegmentUploadRequest(body1, agentToken, "000000010000000100000011")
w1 := httptest.NewRecorder()
router.ServeHTTP(w1, req1)
require.Equal(t, http.StatusNoContent, w1.Code)
// Upload the same segment again — must return 204 (idempotent).
body2 := bytes.NewReader([]byte("wal content"))
req2 := newWalUploadRequest(body2, agentToken, "wal", "000000010000000100000011", "", "")
req2 := newWalSegmentUploadRequest(body2, agentToken, "000000010000000100000011")
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
@@ -228,7 +340,7 @@ func Test_WalUpload_WalSegment_ValidNextSegment_Returns200AndCreatesRecord(t *te
// First WAL segment after the full backup stop segment.
body := bytes.NewReader([]byte("wal segment data"))
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000011", "", "")
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000011")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -255,6 +367,108 @@ func Test_WalUpload_WalSegment_ValidNextSegment_Returns200AndCreatesRecord(t *te
assert.Equal(t, "000000010000000100000011", *walBackup.PgWalSegmentName)
}
func Test_IsWalChainValid_NoFullBackup_ReturnsFalse(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
var response backups_dto.IsWalChainValidResponse
test_utils.MakeGetRequestAndUnmarshal(
t, router,
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
agentToken,
http.StatusOK,
&response,
)
assert.False(t, response.IsValid)
assert.Equal(t, "no_full_backup", response.Error)
}
func Test_IsWalChainValid_FullBackupOnly_ReturnsTrue(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
var response backups_dto.IsWalChainValidResponse
test_utils.MakeGetRequestAndUnmarshal(
t, router,
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
agentToken,
http.StatusOK,
&response,
)
assert.True(t, response.IsValid)
assert.Empty(t, response.Error)
}
func Test_IsWalChainValid_ContinuousChain_ReturnsTrue(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
uploadWalSegment(t, router, agentToken, "000000010000000100000012")
uploadWalSegment(t, router, agentToken, "000000010000000100000013")
var response backups_dto.IsWalChainValidResponse
test_utils.MakeGetRequestAndUnmarshal(
t, router,
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
agentToken,
http.StatusOK,
&response,
)
assert.True(t, response.IsValid)
}
func Test_IsWalChainValid_BrokenChain_ReturnsFalse(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
uploadWalSegment(t, router, agentToken, "000000010000000100000012")
uploadWalSegment(t, router, agentToken, "000000010000000100000013")
// Delete the middle segment to create a gap.
middleSeg, err := backups_core.GetBackupRepository().FindWalSegmentByName(
db.ID, "000000010000000100000012",
)
require.NoError(t, err)
require.NotNil(t, middleSeg)
require.NoError(t, backups_core.GetBackupRepository().DeleteByID(middleSeg.ID))
var response backups_dto.IsWalChainValidResponse
test_utils.MakeGetRequestAndUnmarshal(
t, router,
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
agentToken,
http.StatusOK,
&response,
)
assert.False(t, response.IsValid)
assert.Equal(t, "wal_chain_broken", response.Error)
assert.Equal(t, "000000010000000100000011", response.LastContiguousSegment)
}
func Test_IsWalChainValid_InvalidToken_Returns401(t *testing.T) {
router, db, storage, _, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
resp := test_utils.MakeGetRequest(
t, router,
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
"invalid-token",
http.StatusUnauthorized,
)
assert.Contains(t, string(resp.Body), "invalid agent token")
}
func Test_ReportError_ValidTokenAndError_CreatesFailedBackupRecord(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
@@ -457,29 +671,14 @@ func Test_GetNextFullBackupTime_WalSegmentAfterFullBackup_DoesNotImpactTime(t *t
setHourlyInterval(t, router, db.ID, ownerToken)
// Upload basebackup via API.
bbBody := bytes.NewReader([]byte("basebackup content"))
bbReq := newWalUploadRequest(
bbBody, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
"000000010000000100000001", "000000010000000100000010",
)
bbW := httptest.NewRecorder()
router.ServeHTTP(bbW, bbReq)
require.Equal(t, http.StatusNoContent, bbW.Code)
uploadBasebackup(t, router, agentToken,
"000000010000000100000001", "000000010000000100000010")
// Shift the full backup's CreatedAt to 2 hours ago.
twoHoursAgo := time.Now().UTC().Add(-2 * time.Hour)
updateLastFullBackupTime(t, db.ID, twoHoursAgo)
// Upload WAL segment via API.
walBody := bytes.NewReader([]byte("wal segment content"))
walReq := newWalUploadRequest(
walBody, agentToken, backups_core.PgWalUploadTypeWal,
"000000010000000100000011", "", "",
)
walW := httptest.NewRecorder()
router.ServeHTTP(walW, walReq)
require.Equal(t, http.StatusNoContent, walW.Code)
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
var response backups_dto.GetNextFullBackupTimeResponse
test_utils.MakeGetRequestAndUnmarshal(
@@ -508,15 +707,8 @@ func Test_GetNextFullBackupTime_FailedBasebackup_DoesNotImpactTime(t *testing.T)
setHourlyInterval(t, router, db.ID, ownerToken)
// Upload a successful basebackup via API.
bbBody := bytes.NewReader([]byte("basebackup content"))
bbReq := newWalUploadRequest(
bbBody, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
"000000010000000100000001", "000000010000000100000010",
)
bbW := httptest.NewRecorder()
router.ServeHTTP(bbW, bbReq)
require.Equal(t, http.StatusNoContent, bbW.Code)
uploadBasebackup(t, router, agentToken,
"000000010000000100000001", "000000010000000100000010")
// Shift the full backup's CreatedAt to 2 hours ago.
twoHoursAgo := time.Now().UTC().Add(-2 * time.Hour)
@@ -563,15 +755,8 @@ func Test_GetNextFullBackupTime_NewCompletedFullBackup_ImpactsTime(t *testing.T)
setHourlyInterval(t, router, db.ID, ownerToken)
// Upload first basebackup via API.
bb1 := bytes.NewReader([]byte("first basebackup"))
bb1Req := newWalUploadRequest(
bb1, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
"000000010000000100000001", "000000010000000100000010",
)
bb1W := httptest.NewRecorder()
router.ServeHTTP(bb1W, bb1Req)
require.Equal(t, http.StatusNoContent, bb1W.Code)
uploadBasebackup(t, router, agentToken,
"000000010000000100000001", "000000010000000100000010")
// Shift the first backup's CreatedAt to 3 hours ago.
threeHoursAgo := time.Now().UTC().Add(-3 * time.Hour)
@@ -595,15 +780,8 @@ func Test_GetNextFullBackupTime_NewCompletedFullBackup_ImpactsTime(t *testing.T)
"first next time should be in the past (old backup)",
)
// Upload second basebackup via API (created now).
bb2 := bytes.NewReader([]byte("second basebackup"))
bb2Req := newWalUploadRequest(
bb2, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
"000000010000000100000011", "000000010000000100000020",
)
bb2W := httptest.NewRecorder()
router.ServeHTTP(bb2W, bb2Req)
require.Equal(t, http.StatusNoContent, bb2W.Code)
uploadBasebackup(t, router, agentToken,
"000000010000000100000011", "000000010000000100000020")
var secondResponse backups_dto.GetNextFullBackupTimeResponse
test_utils.MakeGetRequestAndUnmarshal(
@@ -760,6 +938,42 @@ func Test_GetRestorePlan_WithInvalidBackupId_Returns400(t *testing.T) {
assert.Equal(t, "no_backups", errResp.Error)
}
func Test_GetRestorePlan_WithWalSegmentId_ResolvesFullBackupAndReturnsWals(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
uploadWalSegment(t, router, agentToken, "000000010000000100000012")
uploadWalSegment(t, router, agentToken, "000000010000000100000013")
WaitForBackupCompletion(t, db.ID, 3, 5*time.Second)
walSegment, err := backups_core.GetBackupRepository().FindWalSegmentByName(
db.ID, "000000010000000100000012",
)
require.NoError(t, err)
require.NotNil(t, walSegment)
var response backups_dto.GetRestorePlanResponse
test_utils.MakeGetRequestAndUnmarshal(
t, router,
"/api/v1/backups/postgres/wal/restore/plan?backupId="+walSegment.ID.String(),
agentToken,
http.StatusOK,
&response,
)
assert.NotEqual(t, uuid.Nil, response.FullBackup.BackupID)
assert.Equal(t, "000000010000000100000001", response.FullBackup.FullBackupWalStartSegment)
assert.Equal(t, "000000010000000100000010", response.FullBackup.FullBackupWalStopSegment)
require.Len(t, response.WalSegments, 3)
assert.Equal(t, "000000010000000100000011", response.WalSegments[0].SegmentName)
assert.Equal(t, "000000010000000100000012", response.WalSegments[1].SegmentName)
assert.Equal(t, "000000010000000100000013", response.WalSegments[2].SegmentName)
assert.Greater(t, response.TotalSizeBytes, int64(0))
}
func Test_GetRestorePlan_WithInvalidToken_Returns401(t *testing.T) {
router, db, storage, _, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
@@ -817,6 +1031,140 @@ func Test_GetRestorePlan_WithInvalidBackupIdFormat_Returns400(t *testing.T) {
assert.Contains(t, string(resp.Body), "invalid backupId format")
}
func Test_WalUpload_WalSegment_CompletedBackup_HasNonZeroDuration(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
WaitForBackupCompletion(t, db.ID, 1, 5*time.Second)
backups, err := backups_core.GetBackupRepository().FindByDatabaseID(db.ID)
require.NoError(t, err)
var walBackup *backups_core.Backup
for _, b := range backups {
if b.PgWalBackupType != nil &&
*b.PgWalBackupType == backups_core.PgWalBackupTypeWalSegment {
walBackup = b
break
}
}
require.NotNil(t, walBackup)
assert.Equal(t, backups_core.BackupStatusCompleted, walBackup.Status)
assert.Greater(t, walBackup.BackupDurationMs, int64(0),
"WAL segment backup should have non-zero duration")
}
func Test_WalUpload_Basebackup_CompletedBackup_HasNonZeroDuration(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
backupID := uploadBasebackupPhase1(t, router, agentToken)
completeFullBackupUpload(t, router, agentToken, backupID,
"000000010000000100000001", "000000010000000100000010", nil)
backup, err := backups_core.GetBackupRepository().FindByID(backupID)
require.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
assert.Greater(t, backup.BackupDurationMs, int64(0),
"base backup should have non-zero duration")
}
func Test_WalUpload_WalSegment_ProgressUpdatedDuringStream(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
pipeReader, pipeWriter := io.Pipe()
req := newWalSegmentUploadRequest(pipeReader, agentToken, "000000010000000100000011")
recorder := httptest.NewRecorder()
done := make(chan struct{})
go func() {
router.ServeHTTP(recorder, req)
close(done)
}()
// Write some data so the countingReader registers bytes.
_, err := pipeWriter.Write([]byte("wal-segment-progress-data"))
require.NoError(t, err)
// Wait for the progress tracker to tick (1s interval + margin).
time.Sleep(1500 * time.Millisecond)
backups, err := backups_core.GetBackupRepository().FindByDatabaseID(db.ID)
require.NoError(t, err)
var walBackup *backups_core.Backup
for _, b := range backups {
if b.PgWalBackupType != nil &&
*b.PgWalBackupType == backups_core.PgWalBackupTypeWalSegment {
walBackup = b
break
}
}
require.NotNil(t, walBackup)
assert.Equal(t, backups_core.BackupStatusInProgress, walBackup.Status)
assert.Greater(t, walBackup.BackupDurationMs, int64(0),
"duration should be tracked in real-time during upload")
assert.Greater(t, walBackup.BackupSizeMb, float64(0),
"size should be tracked in real-time during upload")
_ = pipeWriter.Close()
<-done
}
func Test_WalUpload_Basebackup_ProgressUpdatedDuringStream(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
pipeReader, pipeWriter := io.Pipe()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/full-start", pipeReader)
req.Header.Set("Authorization", agentToken)
req.Header.Set("Content-Type", "application/octet-stream")
recorder := httptest.NewRecorder()
done := make(chan struct{})
go func() {
router.ServeHTTP(recorder, req)
close(done)
}()
// Write some data so the countingReader registers bytes.
_, err := pipeWriter.Write([]byte("basebackup-progress-data"))
require.NoError(t, err)
// Wait for the progress tracker to tick (1s interval + margin).
time.Sleep(1500 * time.Millisecond)
backups, err := backups_core.GetBackupRepository().FindByDatabaseID(db.ID)
require.NoError(t, err)
var fullBackup *backups_core.Backup
for _, b := range backups {
if b.PgWalBackupType != nil &&
*b.PgWalBackupType == backups_core.PgWalBackupTypeFullBackup {
fullBackup = b
break
}
}
require.NotNil(t, fullBackup)
assert.Equal(t, backups_core.BackupStatusInProgress, fullBackup.Status)
assert.Greater(t, fullBackup.BackupDurationMs, int64(0),
"duration should be tracked in real-time during upload")
assert.Greater(t, fullBackup.BackupSizeMb, float64(0),
"size should be tracked in real-time during upload")
_ = pipeWriter.Close()
<-done
}
func Test_DownloadRestoreFile_UploadThenDownload_ContentMatches(t *testing.T) {
tests := []struct {
name string
@@ -841,15 +1189,18 @@ func Test_DownloadRestoreFile_UploadThenDownload_ContentMatches(t *testing.T) {
uploadContent := "test-basebackup-content-for-download"
body := bytes.NewReader([]byte(uploadContent))
req := newWalUploadRequest(
body, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
"000000010000000100000001", "000000010000000100000010",
)
req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/full-start", body)
req.Header.Set("Authorization", agentToken)
req.Header.Set("Content-Type", "application/octet-stream")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNoContent, w.Code)
require.Equal(t, http.StatusOK, w.Code)
WaitForBackupCompletion(t, db.ID, 0, 5*time.Second)
var uploadResp backups_dto.UploadBasebackupResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &uploadResp))
completeFullBackupUpload(t, router, agentToken, uploadResp.BackupID,
"000000010000000100000001", "000000010000000100000010", nil)
var planResp backups_dto.GetRestorePlanResponse
test_utils.MakeGetRequestAndUnmarshal(
@@ -883,7 +1234,7 @@ func Test_DownloadRestoreFile_WalSegment_UploadThenDownload_ContentMatches(t *te
walContent := "test-wal-segment-content-for-download"
body := bytes.NewReader([]byte(walContent))
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000011", "", "")
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000011")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNoContent, w.Code)
@@ -1088,35 +1439,81 @@ func removeWalTestSetup(db *databases.Database, storage *storages.Storage) {
storages.RemoveTestStorage(storage.ID)
}
func newWalUploadRequest(
func newWalSegmentUploadRequest(
body io.Reader,
agentToken string,
uploadType backups_core.PgWalUploadType,
walSegmentName string,
walStart string,
walStop string,
segmentName string,
) *http.Request {
url := "/api/v1/backups/postgres/wal/upload"
if walStart != "" || walStop != "" {
url += "?fullBackupWalStartSegment=" + walStart + "&fullBackupWalStopSegment=" + walStop
}
req, err := http.NewRequest(http.MethodPost, url, body)
req, err := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/wal", body)
if err != nil {
panic(err)
}
req.Header.Set("Authorization", agentToken)
req.Header.Set("Content-Type", "application/octet-stream")
req.Header.Set("X-Upload-Type", string(uploadType))
if walSegmentName != "" {
req.Header.Set("X-Wal-Segment-Name", walSegmentName)
}
req.Header.Set("X-Wal-Segment-Name", segmentName)
return req
}
func uploadBasebackupPhase1(
t *testing.T,
router *gin.Engine,
agentToken string,
) uuid.UUID {
t.Helper()
body := bytes.NewReader([]byte("test-basebackup-content"))
req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/full-start", body)
req.Header.Set("Authorization", agentToken)
req.Header.Set("Content-Type", "application/octet-stream")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
var response backups_dto.UploadBasebackupResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
require.NotEqual(t, uuid.Nil, response.BackupID)
return response.BackupID
}
func completeFullBackupUpload(
t *testing.T,
router *gin.Engine,
agentToken string,
backupID uuid.UUID,
walStart string,
walStop string,
errMsg *string,
) {
t.Helper()
request := backups_dto.FinalizeBasebackupRequest{
BackupID: backupID,
StartSegment: walStart,
StopSegment: walStop,
Error: errMsg,
}
reqBody, _ := json.Marshal(request)
req, _ := http.NewRequest(
http.MethodPost,
"/api/v1/backups/postgres/wal/upload/full-complete",
bytes.NewReader(reqBody),
)
req.Header.Set("Authorization", agentToken)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func uploadBasebackup(
t *testing.T,
router *gin.Engine,
@@ -1126,15 +1523,8 @@ func uploadBasebackup(
) {
t.Helper()
body := bytes.NewReader([]byte("test-basebackup-content"))
req := newWalUploadRequest(
body, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
walStart, walStop,
)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNoContent, w.Code)
backupID := uploadBasebackupPhase1(t, router, agentToken)
completeFullBackupUpload(t, router, agentToken, backupID, walStart, walStop, nil)
}
func uploadWalSegment(
@@ -1146,9 +1536,12 @@ func uploadWalSegment(
t.Helper()
body := bytes.NewReader([]byte("test-wal-segment-content"))
req := newWalUploadRequest(
body, agentToken, backups_core.PgWalUploadTypeWal, segmentName, "", "",
)
req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/wal", body)
req.Header.Set("Authorization", agentToken)
req.Header.Set("Content-Type", "application/octet-stream")
req.Header.Set("X-Wal-Segment-Name", segmentName)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)

View File

@@ -95,3 +95,33 @@ func CreateTestBackup(databaseID, storageID uuid.UUID) *backups_core.Backup {
return backup
}
type TestBackupOptions struct {
Status backups_core.BackupStatus
CreatedAt time.Time
PgWalBackupType *backups_core.PgWalBackupType
}
// CreateTestBackupWithOptions creates a test backup with custom status, time, and WAL type
func CreateTestBackupWithOptions(
databaseID, storageID uuid.UUID,
opts TestBackupOptions,
) *backups_core.Backup {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: databaseID,
StorageID: storageID,
Status: opts.Status,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
PgWalBackupType: opts.PgWalBackupType,
CreatedAt: opts.CreatedAt,
}
repo := &backups_core.BackupRepository{}
if err := repo.Save(backup); err != nil {
panic(err)
}
return backup
}

View File

@@ -0,0 +1,9 @@
package backups_core
import "time"
type BackupFilters struct {
Statuses []BackupStatus
BeforeDate *time.Time
PgWalBackupType *PgWalBackupType
}

View File

@@ -43,7 +43,8 @@ type Backup struct {
PgVersion *string `json:"pgVersion" gorm:"column:pg_version;type:text"`
PgWalSegmentName *string `json:"pgWalSegmentName" gorm:"column:pg_wal_segment_name;type:text"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
UploadCompletedAt *time.Time `json:"uploadCompletedAt" gorm:"column:upload_completed_at"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
}
func (b *Backup) GenerateFilename(dbName string) {

View File

@@ -349,6 +349,52 @@ func (r *BackupRepository) FindWalSegmentByName(
return &backup, nil
}
func (r *BackupRepository) FindLatestCompletedFullWalBackupBefore(
databaseID uuid.UUID,
before time.Time,
) (*Backup, error) {
var backup Backup
err := storage.
GetDb().
Where(
"database_id = ? AND pg_wal_backup_type = ? AND status = ? AND created_at <= ?",
databaseID,
PgWalBackupTypeFullBackup,
BackupStatusCompleted,
before,
).
Order("created_at DESC").
First(&backup).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &backup, nil
}
func (r *BackupRepository) FindStaleUploadedBasebackups(olderThan time.Time) ([]*Backup, error) {
var backups []*Backup
err := storage.
GetDb().
Where(
"status = ? AND upload_completed_at IS NOT NULL AND upload_completed_at < ?",
BackupStatusInProgress,
olderThan,
).
Find(&backups).Error
if err != nil {
return nil, err
}
return backups, nil
}
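A hypothetical consumer of FindStaleUploadedBasebackups, sketched only to show how the query is meant to be used; the actual background job is not part of this diff, and the helper name and failure message are made up.

// Hypothetical cleanup step (not in this diff): basebackups that finished
// uploading but were never finalized within the cutoff are marked as failed.
func failStaleUploadedBasebackups(repo *BackupRepository, cutoff time.Duration) error {
	stale, err := repo.FindStaleUploadedBasebackups(time.Now().UTC().Add(-cutoff))
	if err != nil {
		return err
	}
	for _, backup := range stale {
		msg := "basebackup upload was never finalized"
		backup.Status = BackupStatusFailed
		backup.FailMessage = &msg
		if err := repo.Save(backup); err != nil {
			return err
		}
	}
	return nil
}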
func (r *BackupRepository) FindLastWalSegmentAfter(
databaseID uuid.UUID,
afterSegmentName string,
@@ -376,3 +422,67 @@ func (r *BackupRepository) FindLastWalSegmentAfter(
return &backup, nil
}
func (r *BackupRepository) FindByDatabaseIDWithFiltersAndPagination(
databaseID uuid.UUID,
filters *BackupFilters,
limit, offset int,
) ([]*Backup, error) {
var backups []*Backup
query := storage.
GetDb().
Where("database_id = ?", databaseID)
if filters != nil {
query = filters.applyToQuery(query)
}
if err := query.
Order("created_at DESC").
Limit(limit).
Offset(offset).
Find(&backups).Error; err != nil {
return nil, err
}
return backups, nil
}
func (r *BackupRepository) CountByDatabaseIDWithFilters(
databaseID uuid.UUID,
filters *BackupFilters,
) (int64, error) {
var count int64
query := storage.
GetDb().
Model(&Backup{}).
Where("database_id = ?", databaseID)
if filters != nil {
query = filters.applyToQuery(query)
}
if err := query.Count(&count).Error; err != nil {
return 0, err
}
return count, nil
}
func (f *BackupFilters) applyToQuery(query *gorm.DB) *gorm.DB {
if len(f.Statuses) > 0 {
query = query.Where("status IN ?", f.Statuses)
}
if f.BeforeDate != nil {
query = query.Where("created_at < ?", *f.BeforeDate)
}
if f.PgWalBackupType != nil {
query = query.Where("pg_wal_backup_type = ?", *f.PgWalBackupType)
}
return query
}
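A minimal sketch of combining these filters from inside the backups_core package; the helper name is hypothetical and it assumes GetBackupRepository() and a valid databaseID.

// Hypothetical helper: failed WAL-segment backups created in the last day,
// first page of 50, plus the total matching count for pagination.
func findRecentFailedWalSegments(databaseID uuid.UUID) ([]*Backup, int64, error) {
	cutoff := time.Now().UTC().Add(-24 * time.Hour)
	walType := PgWalBackupTypeWalSegment
	filters := &BackupFilters{
		Statuses:        []BackupStatus{BackupStatusFailed},
		BeforeDate:      &cutoff,
		PgWalBackupType: &walType,
	}
	repo := GetBackupRepository()
	backups, err := repo.FindByDatabaseIDWithFiltersAndPagination(databaseID, filters, 50, 0)
	if err != nil {
		return nil, 0, err
	}
	total, err := repo.CountByDatabaseIDWithFilters(databaseID, filters)
	if err != nil {
		return nil, 0, err
	}
	return backups, total, nil
}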

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
)
@@ -13,38 +12,31 @@ type DownloadTokenBackgroundService struct {
downloadTokenService *DownloadTokenService
logger *slog.Logger
runOnce sync.Once
hasRun atomic.Bool
hasRun atomic.Bool
}
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
wasAlreadyRun := s.hasRun.Load()
s.runOnce.Do(func() {
s.hasRun.Store(true)
s.logger.Info("Starting download token cleanup background service")
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
s.logger.Error("Failed to clean expired download tokens", "error", err)
}
}
}
})
if wasAlreadyRun {
if s.hasRun.Swap(true) {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
s.logger.Info("Starting download token cleanup background service")
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
s.logger.Error("Failed to clean expired download tokens", "error", err)
}
}
}
}
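The refactor replaces the sync.Once wrapper with a single atomic flag. A stand-alone sketch of the same guard, shown only to illustrate why Swap alone is enough.

package main

import (
	"fmt"
	"sync/atomic"
)

type runner struct{ hasRun atomic.Bool }

func (r *runner) Run() {
	// Swap sets the flag and returns its previous value in one atomic step,
	// so exactly one caller ever gets past this check.
	if r.hasRun.Swap(true) {
		panic("Run() called multiple times")
	}
	fmt.Println("running once")
}

func main() {
	r := &runner{}
	r.Run()
	defer func() {
		if rec := recover(); rec != nil {
			fmt.Println("second call rejected:", rec)
		}
	}()
	r.Run()
}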

View File

@@ -1,9 +1,6 @@
package backups_download
import (
"sync"
"sync/atomic"
"databasus-backend/internal/config"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
@@ -37,8 +34,6 @@ func init() {
downloadTokenBackgroundService = &DownloadTokenBackgroundService{
downloadTokenService: downloadTokenService,
logger: logger.GetLogger(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}

View File

@@ -11,9 +11,12 @@ import (
)
type GetBackupsRequest struct {
DatabaseID string `form:"database_id" binding:"required"`
Limit int `form:"limit"`
Offset int `form:"offset"`
DatabaseID string `form:"database_id" binding:"required"`
Limit int `form:"limit"`
Offset int `form:"offset"`
Statuses []string `form:"status"`
BeforeDate *time.Time `form:"beforeDate"`
PgWalBackupType *string `form:"pgWalBackupType"`
}
type GetBackupsResponse struct {
@@ -44,10 +47,10 @@ type ReportErrorRequest struct {
Error string `json:"error" binding:"required"`
}
type UploadGapResponse struct {
Error string `json:"error"`
ExpectedSegmentName string `json:"expectedSegmentName"`
ReceivedSegmentName string `json:"receivedSegmentName"`
type IsWalChainValidResponse struct {
IsValid bool `json:"isValid"`
Error string `json:"error,omitempty"`
LastContiguousSegment string `json:"lastContiguousSegment,omitempty"`
}
type RestorePlanFullBackup struct {
@@ -77,3 +80,14 @@ type GetRestorePlanResponse struct {
TotalSizeBytes int64 `json:"totalSizeBytes"`
LatestAvailableSegment string `json:"latestAvailableSegment"`
}
type UploadBasebackupResponse struct {
BackupID uuid.UUID `json:"backupId"`
}
type FinalizeBasebackupRequest struct {
BackupID uuid.UUID `json:"backupId" binding:"required"`
StartSegment string `json:"startSegment"`
StopSegment string `json:"stopSegment"`
Error *string `json:"error"`
}
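A hedged sketch of a client building the list query against these form tags. Only the parameter names come from GetBackupsRequest; the endpoint path and the status/type string values are placeholders, since the real enum strings are defined in backups_core.

package main

import (
	"fmt"
	"net/url"
	"time"
)

func main() {
	q := url.Values{}
	q.Set("database_id", "9f4c1c2a-0000-0000-0000-000000000000") // placeholder
	q.Set("limit", "50")
	q.Set("offset", "0")
	q.Add("status", "FAILED")      // placeholder enum strings; repeated
	q.Add("status", "IN_PROGRESS") // "status" values map to Statuses
	q.Set("beforeDate", time.Now().UTC().Format(time.RFC3339))
	q.Set("pgWalBackupType", "WAL_SEGMENT") // placeholder enum string
	fmt.Println("/api/v1/backups?" + q.Encode()) // assumed list path
}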

View File

@@ -2,7 +2,6 @@ package backups_services
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/backuping"
@@ -59,26 +58,11 @@ func GetWalService() *PostgreWalBackupService {
return walService
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
var SetupDependencies = sync.OnceFunc(func() {
backups_config.
GetBackupConfigService().
SetDatabaseStorageChangeListener(backupService)
func SetupDependencies() {
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
backups_config.
GetBackupConfigService().
SetDatabaseStorageChangeListener(backupService)
databases.GetDatabaseService().AddDbRemoveListener(backupService)
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}
databases.GetDatabaseService().AddDbRemoveListener(backupService)
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
})

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"log/slog"
"sync/atomic"
"time"
"github.com/google/uuid"
@@ -30,88 +31,173 @@ type PostgreWalBackupService struct {
backupService *BackupService
}
// UploadWal accepts a streaming WAL segment or basebackup upload from the agent.
// For WAL segments it validates the WAL chain before accepting. Returns an UploadGapResponse
// (409) when the chain is broken so the agent knows to trigger a full basebackup.
func (s *PostgreWalBackupService) UploadWal(
// UploadWalSegment accepts a streaming WAL segment upload from the agent.
// WAL segments are accepted unconditionally.
func (s *PostgreWalBackupService) UploadWalSegment(
ctx context.Context,
database *databases.Database,
uploadType backups_core.PgWalUploadType,
walSegmentName string,
fullBackupWalStartSegment string,
fullBackupWalStopSegment string,
walSegmentSizeBytes int64,
body io.Reader,
) (*backups_dto.UploadGapResponse, error) {
if err := s.validateWalBackupType(database); err != nil {
return nil, err
}
) error {
uploadStart := time.Now().UTC()
if uploadType == backups_core.PgWalUploadTypeBasebackup {
if fullBackupWalStartSegment == "" || fullBackupWalStopSegment == "" {
return nil, fmt.Errorf(
"fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads",
)
}
if err := s.validateWalBackupType(database); err != nil {
return err
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
if err != nil {
return nil, fmt.Errorf("failed to get backup config: %w", err)
return fmt.Errorf("failed to get backup config: %w", err)
}
if backupConfig.Storage == nil {
return nil, fmt.Errorf("no storage configured for database %s", database.ID)
return fmt.Errorf("no storage configured for database %s", database.ID)
}
if uploadType == backups_core.PgWalUploadTypeWal {
// Idempotency: check before chain validation so a successful re-upload is
// not misidentified as a gap.
existing, err := s.backupRepository.FindWalSegmentByName(database.ID, walSegmentName)
if err != nil {
return nil, fmt.Errorf("failed to check for duplicate WAL segment: %w", err)
}
if existing != nil {
return nil, nil
}
gapResp, err := s.validateWalChain(database.ID, walSegmentName, walSegmentSizeBytes)
if err != nil {
return nil, err
}
if gapResp != nil {
return gapResp, nil
}
existing, err := s.backupRepository.FindWalSegmentByName(database.ID, walSegmentName)
if err != nil {
return fmt.Errorf("failed to check for duplicate WAL segment: %w", err)
}
backup := s.createBackupRecord(
if existing != nil {
return nil
}
backup := s.createWalSegmentRecord(
database.ID,
backupConfig.Storage.ID,
uploadType,
database.Name,
walSegmentName,
fullBackupWalStartSegment,
fullBackupWalStopSegment,
backupConfig.Encryption,
)
if err := s.backupRepository.Save(backup); err != nil {
return nil, fmt.Errorf("failed to create backup record: %w", err)
return fmt.Errorf("failed to create backup record: %w", err)
}
sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, body)
inputCounter := &countingReader{r: body}
progressDone := make(chan struct{})
go s.startProgressTracker(backup, inputCounter, uploadStart, progressDone)
sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, inputCounter)
close(progressDone)
if streamErr != nil {
errMsg := streamErr.Error()
backup.BackupDurationMs = time.Since(uploadStart).Milliseconds()
s.markFailed(backup, errMsg)
return nil, fmt.Errorf("upload failed: %w", streamErr)
return fmt.Errorf("upload failed: %w", streamErr)
}
backup.BackupDurationMs = time.Since(uploadStart).Milliseconds()
s.markCompleted(backup, sizeBytes)
return nil, nil
return nil
}
// UploadBasebackup accepts a streaming basebackup upload from the agent (Phase 1).
// The backup stays IN_PROGRESS with UploadCompletedAt set after streaming finishes.
// The agent must call FinalizeBasebackup (Phase 2) with WAL segment names to complete.
func (s *PostgreWalBackupService) UploadBasebackup(
ctx context.Context,
database *databases.Database,
body io.Reader,
) (uuid.UUID, error) {
uploadStart := time.Now().UTC()
if err := s.validateWalBackupType(database); err != nil {
return uuid.Nil, err
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
if err != nil {
return uuid.Nil, fmt.Errorf("failed to get backup config: %w", err)
}
if backupConfig.Storage == nil {
return uuid.Nil, fmt.Errorf("no storage configured for database %s", database.ID)
}
backup := s.createBasebackupRecord(
database.ID,
backupConfig.Storage.ID,
database.Name,
backupConfig.Encryption,
)
if err := s.backupRepository.Save(backup); err != nil {
return uuid.Nil, fmt.Errorf("failed to create backup record: %w", err)
}
inputCounter := &countingReader{r: body}
progressDone := make(chan struct{})
go s.startProgressTracker(backup, inputCounter, uploadStart, progressDone)
sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, inputCounter)
close(progressDone)
if streamErr != nil {
errMsg := streamErr.Error()
backup.BackupDurationMs = time.Since(uploadStart).Milliseconds()
s.markFailed(backup, errMsg)
return uuid.Nil, fmt.Errorf("upload failed: %w", streamErr)
}
now := time.Now().UTC()
backup.UploadCompletedAt = &now
backup.BackupSizeMb = float64(sizeBytes) / (1024 * 1024)
backup.BackupDurationMs = time.Since(uploadStart).Milliseconds()
if err := s.backupRepository.Save(backup); err != nil {
return uuid.Nil, fmt.Errorf("failed to update backup after upload: %w", err)
}
return backup.ID, nil
}
// FinalizeBasebackup completes a previously uploaded basebackup (Phase 2).
// Sets WAL segment names and marks the backup as COMPLETED, or marks it FAILED if errorMsg is provided.
func (s *PostgreWalBackupService) FinalizeBasebackup(
database *databases.Database,
backupID uuid.UUID,
startSegment string,
stopSegment string,
errorMsg *string,
) error {
if err := s.validateWalBackupType(database); err != nil {
return err
}
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return fmt.Errorf("backup not found: %w", err)
}
if backup.DatabaseID != database.ID {
return fmt.Errorf("backup does not belong to this database")
}
if backup.Status != backups_core.BackupStatusInProgress || backup.UploadCompletedAt == nil {
return fmt.Errorf("backup is not awaiting finalization")
}
if errorMsg != nil {
s.markFailed(backup, *errorMsg)
return nil
}
backup.PgFullBackupWalStartSegmentName = &startSegment
backup.PgFullBackupWalStopSegmentName = &stopSegment
backup.Status = backups_core.BackupStatusCompleted
if err := s.backupRepository.Save(backup); err != nil {
return fmt.Errorf("failed to finalize backup: %w", err)
}
return nil
}
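Taken together, UploadBasebackup and FinalizeBasebackup form a two-phase protocol. Below is a client-side sketch of that flow as the agent might implement it, not part of this service; baseURL, agentToken, and backupStream are placeholders, and the net/http, encoding/json, bytes, io, and fmt imports are omitted.

// Sketch of the agent-side two-phase flow (assumed, not from this diff).
func uploadFullBackup(baseURL, agentToken string, backupStream io.Reader, startSeg, stopSeg string) error {
	// Phase 1: stream the basebackup; the server answers with the backup ID.
	req, err := http.NewRequest(http.MethodPost,
		baseURL+"/api/v1/backups/postgres/wal/upload/full-start", backupStream)
	if err != nil {
		return err
	}
	req.Header.Set("Authorization", agentToken)
	req.Header.Set("Content-Type", "application/octet-stream")
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("full-start returned %d", resp.StatusCode)
	}
	var started struct {
		BackupID string `json:"backupId"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&started); err != nil {
		return err
	}

	// Phase 2: report the WAL start/stop segments so the backup is finalized.
	body, err := json.Marshal(map[string]string{
		"backupId":     started.BackupID,
		"startSegment": startSeg,
		"stopSegment":  stopSeg,
	})
	if err != nil {
		return err
	}
	finReq, err := http.NewRequest(http.MethodPost,
		baseURL+"/api/v1/backups/postgres/wal/upload/full-complete", bytes.NewReader(body))
	if err != nil {
		return err
	}
	finReq.Header.Set("Authorization", agentToken)
	finReq.Header.Set("Content-Type", "application/json")
	finResp, err := http.DefaultClient.Do(finReq)
	if err != nil {
		return err
	}
	defer finResp.Body.Close()
	if finResp.StatusCode != http.StatusOK {
		return fmt.Errorf("full-complete returned %d", finResp.StatusCode)
	}
	return nil
}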
func (s *PostgreWalBackupService) GetRestorePlan(
@@ -299,97 +385,97 @@ func (s *PostgreWalBackupService) ReportError(
return nil
}
func (s *PostgreWalBackupService) validateWalChain(
databaseID uuid.UUID,
incomingSegment string,
walSegmentSizeBytes int64,
) (*backups_dto.UploadGapResponse, error) {
fullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(databaseID)
// IsWalChainValid checks whether the WAL chain is continuous since the last completed full backup.
func (s *PostgreWalBackupService) IsWalChainValid(
database *databases.Database,
) (*backups_dto.IsWalChainValidResponse, error) {
if err := s.validateWalBackupType(database); err != nil {
return nil, err
}
fullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(database.ID)
if err != nil {
return nil, fmt.Errorf("failed to query full backup: %w", err)
}
// No full backup exists yet: there is no chain anchor, so the chain cannot be valid.
if fullBackup == nil || fullBackup.PgFullBackupWalStopSegmentName == nil {
return &backups_dto.UploadGapResponse{
Error: "no_full_backup",
ExpectedSegmentName: "",
ReceivedSegmentName: incomingSegment,
return &backups_dto.IsWalChainValidResponse{
IsValid: false,
Error: "no_full_backup",
}, nil
}
stopSegment := *fullBackup.PgFullBackupWalStopSegmentName
startSegment := ""
if fullBackup.PgFullBackupWalStartSegmentName != nil {
startSegment = *fullBackup.PgFullBackupWalStartSegmentName
}
lastWal, err := s.backupRepository.FindLastWalSegmentAfter(databaseID, stopSegment)
walSegments, err := s.backupRepository.FindCompletedWalSegmentsAfter(database.ID, startSegment)
if err != nil {
return nil, fmt.Errorf("failed to query last WAL segment: %w", err)
return nil, fmt.Errorf("failed to query WAL segments: %w", err)
}
walCalculator := util_wal.NewWalCalculator(walSegmentSizeBytes)
var chainTail string
if lastWal != nil && lastWal.PgWalSegmentName != nil {
chainTail = *lastWal.PgWalSegmentName
} else {
chainTail = stopSegment
}
expectedNext, err := walCalculator.NextSegment(chainTail)
if err != nil {
return nil, fmt.Errorf("WAL arithmetic failed for %q: %w", chainTail, err)
}
if incomingSegment != expectedNext {
return &backups_dto.UploadGapResponse{
Error: "gap_detected",
ExpectedSegmentName: expectedNext,
ReceivedSegmentName: incomingSegment,
chainErr := s.validateRestoreWalChain(fullBackup, walSegments)
if chainErr != nil {
return &backups_dto.IsWalChainValidResponse{
IsValid: false,
Error: chainErr.Error,
LastContiguousSegment: chainErr.LastContiguousSegment,
}, nil
}
return nil, nil
return &backups_dto.IsWalChainValidResponse{
IsValid: true,
}, nil
}
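A sketch of how an agent might consume this check before archiving more WAL, falling back to a fresh basebackup when the chain is broken. triggerFullBackup is a hypothetical callback; baseURL and agentToken are placeholders, and imports are omitted.

// Assumed agent-side usage of the is-wal-chain-valid endpoint.
func ensureWalChain(baseURL, agentToken string, triggerFullBackup func() error) error {
	req, err := http.NewRequest(http.MethodGet,
		baseURL+"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup", nil)
	if err != nil {
		return err
	}
	req.Header.Set("Authorization", agentToken)
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	var status struct {
		IsValid               bool   `json:"isValid"`
		Error                 string `json:"error,omitempty"`
		LastContiguousSegment string `json:"lastContiguousSegment,omitempty"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
		return err
	}
	if status.IsValid {
		return nil
	}
	// "no_full_backup" or "wal_chain_broken": a new basebackup re-anchors the chain.
	return triggerFullBackup()
}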
func (s *PostgreWalBackupService) createBackupRecord(
func (s *PostgreWalBackupService) createBasebackupRecord(
databaseID uuid.UUID,
storageID uuid.UUID,
uploadType backups_core.PgWalUploadType,
dbName string,
walSegmentName string,
fullBackupWalStartSegment string,
fullBackupWalStopSegment string,
encryption backups_config.BackupEncryption,
) *backups_core.Backup {
now := time.Now().UTC()
walBackupType := backups_core.PgWalBackupTypeFullBackup
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: databaseID,
StorageID: storageID,
Status: backups_core.BackupStatusInProgress,
Encryption: encryption,
CreatedAt: now,
ID: uuid.New(),
DatabaseID: databaseID,
StorageID: storageID,
Status: backups_core.BackupStatusInProgress,
PgWalBackupType: &walBackupType,
Encryption: encryption,
CreatedAt: now,
}
backup.GenerateFilename(dbName)
if uploadType == backups_core.PgWalUploadTypeBasebackup {
walBackupType := backups_core.PgWalBackupTypeFullBackup
backup.PgWalBackupType = &walBackupType
return backup
}
if fullBackupWalStartSegment != "" {
backup.PgFullBackupWalStartSegmentName = &fullBackupWalStartSegment
}
func (s *PostgreWalBackupService) createWalSegmentRecord(
databaseID uuid.UUID,
storageID uuid.UUID,
dbName string,
walSegmentName string,
encryption backups_config.BackupEncryption,
) *backups_core.Backup {
now := time.Now().UTC()
walBackupType := backups_core.PgWalBackupTypeWalSegment
if fullBackupWalStopSegment != "" {
backup.PgFullBackupWalStopSegmentName = &fullBackupWalStopSegment
}
} else {
walBackupType := backups_core.PgWalBackupTypeWalSegment
backup.PgWalBackupType = &walBackupType
backup.PgWalSegmentName = &walSegmentName
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: databaseID,
StorageID: storageID,
Status: backups_core.BackupStatusInProgress,
PgWalBackupType: &walBackupType,
PgWalSegmentName: &walSegmentName,
Encryption: encryption,
CreatedAt: now,
}
backup.GenerateFilename(dbName)
return backup
}
@@ -418,7 +504,7 @@ func (s *PostgreWalBackupService) streamDirect(
return 0, err
}
return cr.n, nil
return cr.n.Load(), nil
}
func (s *PostgreWalBackupService) streamEncrypted(
@@ -479,7 +565,7 @@ func (s *PostgreWalBackupService) streamEncrypted(
backup.EncryptionSalt = &encryptionSetup.SaltBase64
backup.EncryptionIV = &encryptionSetup.NonceBase64
return cr.n, nil
return cr.n.Load(), nil
}
func (s *PostgreWalBackupService) markCompleted(backup *backups_core.Backup, sizeBytes int64) {
@@ -497,6 +583,31 @@ func (s *PostgreWalBackupService) markCompleted(backup *backups_core.Backup, siz
}
}
func (s *PostgreWalBackupService) startProgressTracker(
backup *backups_core.Backup,
inputCounter *countingReader,
uploadStart time.Time,
done <-chan struct{},
) {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-done:
return
case <-ticker.C:
backup.BackupDurationMs = time.Since(uploadStart).Milliseconds()
backup.BackupSizeMb = float64(inputCounter.n.Load()) / (1024 * 1024)
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("failed to update backup progress",
"backupId", backup.ID, "error", err)
}
}
}
}
func (s *PostgreWalBackupService) markFailed(backup *backups_core.Backup, errMsg string) {
backup.Status = backups_core.BackupStatusFailed
backup.FailMessage = &errMsg
@@ -510,11 +621,32 @@ func (s *PostgreWalBackupService) resolveFullBackup(
databaseID uuid.UUID,
backupID *uuid.UUID,
) (*backups_core.Backup, error) {
if backupID != nil {
return s.backupRepository.FindCompletedFullWalBackupByID(databaseID, *backupID)
if backupID == nil {
return s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(databaseID)
}
return s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(databaseID)
fullBackup, err := s.backupRepository.FindCompletedFullWalBackupByID(databaseID, *backupID)
if err != nil {
return nil, err
}
if fullBackup != nil {
return fullBackup, nil
}
backup, err := s.backupRepository.FindByID(*backupID)
if err != nil {
return nil, nil
}
if backup.DatabaseID != databaseID ||
backup.Status != backups_core.BackupStatusCompleted ||
backup.PgWalBackupType == nil ||
*backup.PgWalBackupType != backups_core.PgWalBackupTypeWalSegment {
return nil, nil
}
return s.backupRepository.FindLatestCompletedFullWalBackupBefore(databaseID, backup.CreatedAt)
}
func (s *PostgreWalBackupService) validateRestoreWalChain(
@@ -602,12 +734,12 @@ func (s *PostgreWalBackupService) validateWalBackupType(database *databases.Data
type countingReader struct {
r io.Reader
n int64
n atomic.Int64
}
func (cr *countingReader) Read(p []byte) (n int, err error) {
n, err = cr.r.Read(p)
cr.n += int64(n)
func (cr *countingReader) Read(p []byte) (int, error) {
n, err := cr.r.Read(p)
cr.n.Add(int64(n))
return n, err
}
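A stand-alone sketch of the countingReader pattern: the streaming goroutine advances the counter while the progress tracker reads it concurrently, which is why n is an atomic.Int64 rather than a plain int64.

package main

import (
	"fmt"
	"io"
	"strings"
	"sync/atomic"
	"time"
)

type countingReader struct {
	r io.Reader
	n atomic.Int64
}

func (cr *countingReader) Read(p []byte) (int, error) {
	n, err := cr.r.Read(p)
	cr.n.Add(int64(n))
	return n, err
}

func main() {
	cr := &countingReader{r: strings.NewReader(strings.Repeat("x", 1<<20))}
	done := make(chan struct{})
	go func() {
		_, _ = io.Copy(io.Discard, cr) // stands in for streaming to storage
		close(done)
	}()
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-done:
			fmt.Printf("uploaded %.2f MB\n", float64(cr.n.Load())/(1024*1024))
			return
		case <-ticker.C:
			fmt.Printf("progress: %d bytes\n", cr.n.Load())
		}
	}
}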

View File

@@ -109,6 +109,7 @@ func (s *BackupService) GetBackups(
user *users_models.User,
databaseID uuid.UUID,
limit, offset int,
filters *backups_core.BackupFilters,
) (*backups_dto.GetBackupsResponse, error) {
database, err := s.databaseService.GetDatabaseByID(databaseID)
if err != nil {
@@ -134,12 +135,14 @@ func (s *BackupService) GetBackups(
offset = 0
}
backups, err := s.backupRepository.FindByDatabaseIDWithPagination(databaseID, limit, offset)
backups, err := s.backupRepository.FindByDatabaseIDWithFiltersAndPagination(
databaseID, filters, limit, offset,
)
if err != nil {
return nil, err
}
total, err := s.backupRepository.CountByDatabaseID(databaseID)
total, err := s.backupRepository.CountByDatabaseIDWithFilters(databaseID, filters)
if err != nil {
return nil, err
}

View File

@@ -109,6 +109,7 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
"--routines",
"--quick",
"--skip-extended-insert",
"--skip-add-locks",
"--verbose",
}
@@ -129,6 +130,8 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
if mdb.IsHttps {
args = append(args, "--ssl")
args = append(args, "--skip-ssl-verify-server-cert")
} else {
args = append(args, "--skip-ssl")
}
if mdb.Database != nil && *mdb.Database != "" {
@@ -278,15 +281,9 @@ func (uc *CreateMariadbBackupUsecase) createTempMyCnfFile(
mdbConfig *mariadbtypes.MariadbDatabase,
password string,
) (string, error) {
tempFolder := config.GetEnv().TempFolder
if err := os.MkdirAll(tempFolder, 0o700); err != nil {
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
}
if err := os.Chmod(tempFolder, 0o700); err != nil {
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
}
tempDir, err := os.MkdirTemp(tempFolder, "mycnf_"+uuid.New().String())
// Credential files use OS temp dir (/tmp) because some filesystems
// (e.g. ZFS on TrueNAS) ignore chmod, causing "group or world access" errors.
tempDir, err := os.MkdirTemp(os.TempDir(), "mycnf_"+uuid.New().String())
if err != nil {
return "", fmt.Errorf("failed to create temp directory: %w", err)
}

View File

@@ -108,6 +108,7 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
"--set-gtid-purged=OFF",
"--quick",
"--skip-extended-insert",
"--skip-add-locks",
"--verbose",
}
@@ -126,6 +127,8 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
if my.IsHttps {
args = append(args, "--ssl-mode=REQUIRED")
} else {
args = append(args, "--ssl-mode=DISABLED")
}
if my.Database != nil && *my.Database != "" {
@@ -297,15 +300,9 @@ func (uc *CreateMysqlBackupUsecase) createTempMyCnfFile(
myConfig *mysqltypes.MysqlDatabase,
password string,
) (string, error) {
tempFolder := config.GetEnv().TempFolder
if err := os.MkdirAll(tempFolder, 0o700); err != nil {
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
}
if err := os.Chmod(tempFolder, 0o700); err != nil {
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
}
tempDir, err := os.MkdirTemp(tempFolder, "mycnf_"+uuid.New().String())
// Credential files use OS temp dir (/tmp) because some filesystems
// (e.g. ZFS on TrueNAS) ignore chmod, causing "group or world access" errors.
tempDir, err := os.MkdirTemp(os.TempDir(), "mycnf_"+uuid.New().String())
if err != nil {
return "", fmt.Errorf("failed to create temp directory: %w", err)
}
@@ -326,6 +323,8 @@ port=%d
if myConfig.IsHttps {
content += "ssl-mode=REQUIRED\n"
} else {
content += "ssl-mode=DISABLED\n"
}
err = os.WriteFile(myCnfFile, []byte(content), 0o600)

View File

@@ -747,15 +747,9 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
escapedPassword,
)
tempFolder := config.GetEnv().TempFolder
if err := os.MkdirAll(tempFolder, 0o700); err != nil {
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
}
if err := os.Chmod(tempFolder, 0o700); err != nil {
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
}
tempDir, err := os.MkdirTemp(tempFolder, "pgpass_"+uuid.New().String())
// Credential files use OS temp dir (/tmp) because some filesystems
// (e.g. ZFS on TrueNAS) ignore chmod, causing "group or world access" errors.
tempDir, err := os.MkdirTemp(os.TempDir(), "pgpass_"+uuid.New().String())
if err != nil {
return "", fmt.Errorf("failed to create temporary directory: %w", err)
}

View File

@@ -16,7 +16,6 @@ type BackupConfigController struct {
func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
router.POST("/backup-configs/save", c.SaveBackupConfig)
router.GET("/backup-configs/database/:id/plan", c.GetDatabasePlan)
router.GET("/backup-configs/database/:id", c.GetBackupConfigByDbID)
router.GET("/backup-configs/storage/:id/is-using", c.IsStorageUsing)
router.GET("/backup-configs/storage/:id/databases-count", c.CountDatabasesForStorage)
@@ -93,39 +92,6 @@ func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
ctx.JSON(http.StatusOK, backupConfig)
}
// GetDatabasePlan
// @Summary Get database plan by database ID
// @Description Get the plan limits for a specific database (max backup size, max total size, max storage period)
// @Tags backup-configs
// @Produce json
// @Param id path string true "Database ID"
// @Success 200 {object} plans.DatabasePlan
// @Failure 400 {object} map[string]string "Invalid database ID"
// @Failure 401 {object} map[string]string "User not authenticated"
// @Failure 404 {object} map[string]string "Database not found or access denied"
// @Router /backup-configs/database/{id}/plan [get]
func (c *BackupConfigController) GetDatabasePlan(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
return
}
plan, err := c.backupConfigService.GetDatabasePlan(user, id)
if err != nil {
ctx.JSON(http.StatusNotFound, gin.H{"error": "database plan not found"})
return
}
ctx.JSON(http.StatusOK, plan)
}
// IsStorageUsing
// @Summary Check if storage is being used
// @Description Check if a storage is currently being used by any backup configuration

View File

@@ -17,14 +17,12 @@ import (
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
local_storage "databasus-backend/internal/features/storages/models/local"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/storage"
"databasus-backend/internal/util/period"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
@@ -326,218 +324,13 @@ func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T)
&response,
)
var plan plans.DatabasePlan
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
"Bearer "+owner.Token,
http.StatusOK,
&plan,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.False(t, response.IsBackupsEnabled)
assert.Equal(t, plan.MaxStoragePeriod, response.RetentionTimePeriod)
assert.Equal(t, plan.MaxBackupSizeMB, response.MaxBackupSizeMB)
assert.Equal(t, plan.MaxBackupsTotalSizeMB, response.MaxBackupsTotalSizeMB)
assert.True(t, response.IsRetryIfFailed)
assert.Equal(t, 3, response.MaxFailedTriesCount)
assert.NotNil(t, response.BackupInterval)
}
func Test_GetDatabasePlan_ForNewDatabase_PlanAlwaysReturned(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
var response plans.DatabasePlan
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
"Bearer "+owner.Token,
http.StatusOK,
&response,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.NotNil(t, response.MaxBackupSizeMB)
assert.NotNil(t, response.MaxBackupsTotalSizeMB)
assert.NotEmpty(t, response.MaxStoragePeriod)
}
func Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Get plan via API (triggers auto-creation)
var plan plans.DatabasePlan
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
"Bearer "+owner.Token,
http.StatusOK,
&plan,
)
assert.Equal(t, database.ID, plan.DatabaseID)
// Adjust plan limits directly in database to fixed restrictive values
err := storage.GetDb().Model(&plans.DatabasePlan{}).
Where("database_id = ?", database.ID).
Updates(map[string]any{
"max_backup_size_mb": 100,
"max_backups_total_size_mb": 1000,
"max_storage_period": period.PeriodMonth,
}).Error
assert.NoError(t, err)
// Test 1: Try to save backup config with exceeded backup size limit
timeOfDay := "04:00"
backupConfigExceededSize := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 200, // Exceeds limit of 100
MaxBackupsTotalSizeMB: 800,
}
respExceededSize := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigExceededSize,
http.StatusBadRequest,
)
assert.Contains(t, string(respExceededSize.Body), "max backup size exceeds plan limit")
// Test 2: Try to save backup config with exceeded total size limit
backupConfigExceededTotal := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 50,
MaxBackupsTotalSizeMB: 2000, // Exceeds limit of 1000
}
respExceededTotal := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigExceededTotal,
http.StatusBadRequest,
)
assert.Contains(t, string(respExceededTotal.Body), "max total backups size exceeds plan limit")
// Test 3: Try to save backup config with exceeded storage period limit
backupConfigExceededPeriod := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodYear, // Exceeds limit of Month
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 80,
MaxBackupsTotalSizeMB: 800,
}
respExceededPeriod := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigExceededPeriod,
http.StatusBadRequest,
)
assert.Contains(t, string(respExceededPeriod.Body), "storage period exceeds plan limit")
// Test 4: Save backup config within all limits - should succeed
backupConfigValid := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek, // Within Month limit
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 80, // Within 100 limit
MaxBackupsTotalSizeMB: 800, // Within 1000 limit
}
var responseValid BackupConfig
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigValid,
http.StatusOK,
&responseValid,
)
assert.Equal(t, database.ID, responseValid.DatabaseID)
assert.Equal(t, int64(80), responseValid.MaxBackupSizeMB)
assert.Equal(t, int64(800), responseValid.MaxBackupsTotalSizeMB)
assert.Equal(t, period.PeriodWeek, responseValid.RetentionTimePeriod)
}
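For orientation, the three rejection messages asserted in the test above imply per-field comparisons between the submitted backup config and the database plan. The repository's actual validator is not shown in this diff; the following is a minimal, self-contained sketch with assumed type and field names (planLimits, backupConfig, validateAgainstPlan) that reproduces the same three checks:

package main

import (
	"errors"
	"fmt"
)

// planLimits and backupConfig are simplified stand-ins for the plan and config
// models used in the tests above; they are not the repository's types.
type planLimits struct {
	MaxBackupSizeMB       int64
	MaxBackupsTotalSizeMB int64
	MaxStoragePeriodDays  int
}

type backupConfig struct {
	MaxBackupSizeMB       int64
	MaxBackupsTotalSizeMB int64
	RetentionPeriodDays   int
}

// validateAgainstPlan returns the same error texts the test asserts when a
// configured value exceeds the corresponding plan limit.
func validateAgainstPlan(cfg backupConfig, plan planLimits) error {
	if cfg.MaxBackupSizeMB > plan.MaxBackupSizeMB {
		return errors.New("max backup size exceeds plan limit")
	}
	if cfg.MaxBackupsTotalSizeMB > plan.MaxBackupsTotalSizeMB {
		return errors.New("max total backups size exceeds plan limit")
	}
	if cfg.RetentionPeriodDays > plan.MaxStoragePeriodDays {
		return errors.New("storage period exceeds plan limit")
	}
	return nil
}

func main() {
	plan := planLimits{MaxBackupSizeMB: 100, MaxBackupsTotalSizeMB: 1000, MaxStoragePeriodDays: 30}
	cfg := backupConfig{MaxBackupSizeMB: 200, MaxBackupsTotalSizeMB: 800, RetentionPeriodDays: 7}
	fmt.Println(validateAgainstPlan(cfg, plan)) // "max backup size exceeds plan limit"
}

Note that the real service compares retention as a period enum (week/month/year) rather than plain day counts; the sketch simplifies that dimension to integer days.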
func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string


@@ -2,14 +2,11 @@ package backups_config
import (
"sync"
"sync/atomic"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/logger"
)
var (
@@ -20,7 +17,6 @@ var (
storages.GetStorageService(),
notifiers.GetNotifierService(),
workspaces_services.GetWorkspaceService(),
plans.GetDatabasePlanService(),
nil,
}
)
@@ -37,21 +33,6 @@ func GetBackupConfigService() *BackupConfigService {
return backupConfigService
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}
var SetupDependencies = sync.OnceFunc(func() {
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
})
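The replacement above relies on sync.OnceFunc (Go 1.21+), which returns a function that executes the wrapped closure exactly once; repeated calls become no-ops, so the manual sync.Once/atomic.Bool bookkeeping is no longer needed. A minimal standalone sketch of the idiom, with hypothetical names not taken from this repository:

package main

import (
	"fmt"
	"sync"
)

// setupDependencies is built with sync.OnceFunc: the closure runs on the first
// call and every later call returns immediately without re-running it.
var setupDependencies = sync.OnceFunc(func() {
	fmt.Println("wiring dependencies")
})

func main() {
	setupDependencies() // prints "wiring dependencies"
	setupDependencies() // no-op: the wrapped closure does not run again
}

One behavioral difference against the removed version: a second call is now silently ignored instead of logging a warning.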
Some files were not shown because too many files have changed in this diff.