mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 08:41:58 +02:00
Compare commits
57 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8c8c712d97 | ||
|
|
4e1cee2aa2 | ||
|
|
18b8178608 | ||
|
|
02d9cda86f | ||
|
|
cefedb6ddd | ||
|
|
27d891fb34 | ||
|
|
d1c41ed53a | ||
|
|
f287967b5d | ||
|
|
44ddcb836e | ||
|
|
7913c1b474 | ||
|
|
189573fa1b | ||
|
|
63e23b2489 | ||
|
|
1926096377 | ||
|
|
0a131511a8 | ||
|
|
aa01ce0b76 | ||
|
|
1ac0eb4d5b | ||
|
|
c7d091fe51 | ||
|
|
b1dfd1c425 | ||
|
|
4bee78646a | ||
|
|
3a5a53c92d | ||
|
|
f0ab470a84 | ||
|
|
f728fda759 | ||
|
|
80b5df6283 | ||
|
|
67556a0db1 | ||
|
|
c4cf7f8446 | ||
|
|
61a0bcabb1 | ||
|
|
f1e289c421 | ||
|
|
c0952e057f | ||
|
|
b4d4e0a1d7 | ||
|
|
c648e9c29f | ||
|
|
3fce6d2a99 | ||
|
|
198b94ba9d | ||
|
|
80cd0bf5d3 | ||
|
|
231e3cc709 | ||
|
|
8cf0fdacb1 | ||
|
|
2d28af19dc | ||
|
|
67dc257fda | ||
|
|
881167f812 | ||
|
|
cf807cfc54 | ||
|
|
df91651709 | ||
|
|
b0592dae9e | ||
|
|
c997202484 | ||
|
|
a17ea2f3e2 | ||
|
|
856aa1c256 | ||
|
|
f60f677351 | ||
|
|
4c980746ab | ||
|
|
89197bbbc6 | ||
|
|
e2ac5bfbd7 | ||
|
|
cf6e8f212a | ||
|
|
6ee7e02f5d | ||
|
|
14bcd3d70b | ||
|
|
5faa11f82a | ||
|
|
2c4e3e567b | ||
|
|
82d615545b | ||
|
|
e913f4c32e | ||
|
|
57a75918e4 | ||
|
|
8a601c7f68 |
109
.github/workflows/ci-release.yml
vendored
109
.github/workflows/ci-release.yml
vendored
@@ -9,6 +9,7 @@ on:
|
||||
|
||||
jobs:
|
||||
lint-backend:
|
||||
if: github.ref != 'refs/heads/develop'
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: golang:1.26.1
|
||||
@@ -55,6 +56,7 @@ jobs:
|
||||
git diff --exit-code go.mod go.sum || (echo "go mod tidy made changes, please run 'go mod tidy' and commit the changes" && exit 1)
|
||||
|
||||
lint-frontend:
|
||||
if: github.ref != 'refs/heads/develop'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out code
|
||||
@@ -87,6 +89,7 @@ jobs:
|
||||
npm run build
|
||||
|
||||
lint-agent:
|
||||
if: github.ref != 'refs/heads/develop'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out code
|
||||
@@ -120,6 +123,7 @@ jobs:
|
||||
git diff --exit-code go.mod go.sum || (echo "go mod tidy made changes, please run 'go mod tidy' and commit the changes" && exit 1)
|
||||
|
||||
test-frontend:
|
||||
if: github.ref != 'refs/heads/develop'
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint-frontend]
|
||||
steps:
|
||||
@@ -142,6 +146,7 @@ jobs:
|
||||
npm run test
|
||||
|
||||
test-agent:
|
||||
if: github.ref != 'refs/heads/develop'
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint-agent]
|
||||
steps:
|
||||
@@ -165,6 +170,7 @@ jobs:
|
||||
go test -count=1 -failfast ./internal/...
|
||||
|
||||
e2e-agent:
|
||||
if: github.ref != 'refs/heads/develop'
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint-agent]
|
||||
steps:
|
||||
@@ -183,9 +189,34 @@ jobs:
|
||||
docker compose down -v --rmi local || true
|
||||
rm -rf artifacts || true
|
||||
|
||||
e2e-agent-backup-restore:
|
||||
if: github.ref != 'refs/heads/develop'
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint-agent]
|
||||
strategy:
|
||||
matrix:
|
||||
pg_version: [15, 16, 17, 18]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Run backup-restore e2e (PG ${{ matrix.pg_version }})
|
||||
run: |
|
||||
cd agent
|
||||
make e2e-backup-restore PG_VERSION=${{ matrix.pg_version }}
|
||||
|
||||
- name: Cleanup
|
||||
if: always()
|
||||
run: |
|
||||
cd agent/e2e
|
||||
docker compose -f docker-compose.backup-restore.yml down -v --rmi local || true
|
||||
rm -rf artifacts || true
|
||||
|
||||
# Self-hosted: performant high-frequency CPU is used to start many containers and run tests fast. Tests
|
||||
# step is bottle-neck, because we need a lot of containers and cannot parallelize tests due to shared resources
|
||||
test-backend:
|
||||
if: github.ref != 'refs/heads/develop'
|
||||
runs-on: self-hosted
|
||||
needs: [lint-backend]
|
||||
container:
|
||||
@@ -273,9 +304,6 @@ jobs:
|
||||
TEST_MARIADB_114_PORT=33114
|
||||
TEST_MARIADB_118_PORT=33118
|
||||
TEST_MARIADB_120_PORT=33120
|
||||
# testing Telegram
|
||||
TEST_TELEGRAM_BOT_TOKEN=${{ secrets.TEST_TELEGRAM_BOT_TOKEN }}
|
||||
TEST_TELEGRAM_CHAT_ID=${{ secrets.TEST_TELEGRAM_CHAT_ID }}
|
||||
# supabase
|
||||
TEST_SUPABASE_HOST=${{ secrets.TEST_SUPABASE_HOST }}
|
||||
TEST_SUPABASE_PORT=${{ secrets.TEST_SUPABASE_PORT }}
|
||||
@@ -514,11 +542,47 @@ jobs:
|
||||
|
||||
echo "Cleanup complete"
|
||||
|
||||
build-and-push-dev:
|
||||
runs-on: self-hosted
|
||||
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/develop' }}
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up QEMU (enables multi-arch emulation)
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Build and push dev image
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
platforms: linux/amd64,linux/arm64
|
||||
build-args: |
|
||||
APP_VERSION=dev-${{ github.sha }}
|
||||
tags: |
|
||||
databasus/databasus-dev:latest
|
||||
databasus/databasus-dev:${{ github.sha }}
|
||||
|
||||
determine-version:
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: node:20
|
||||
needs: [test-backend, test-frontend, test-agent, e2e-agent]
|
||||
needs: [test-backend, test-frontend, test-agent, e2e-agent, e2e-agent-backup-restore]
|
||||
if: ${{ github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, '[skip-release]') }}
|
||||
outputs:
|
||||
should_release: ${{ steps.version_bump.outputs.should_release }}
|
||||
@@ -609,43 +673,6 @@ jobs:
|
||||
echo "No version bump needed"
|
||||
fi
|
||||
|
||||
build-only:
|
||||
runs-on: self-hosted
|
||||
needs: [test-backend, test-frontend, test-agent, e2e-agent]
|
||||
if: ${{ github.ref == 'refs/heads/main' && contains(github.event.head_commit.message, '[skip-release]') }}
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up QEMU (enables multi-arch emulation)
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Build and push SHA-only tags
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
platforms: linux/amd64,linux/arm64
|
||||
build-args: |
|
||||
APP_VERSION=dev-${{ github.sha }}
|
||||
tags: |
|
||||
databasus/databasus:latest
|
||||
databasus/databasus:${{ github.sha }}
|
||||
|
||||
build-and-push:
|
||||
runs-on: self-hosted
|
||||
needs: [determine-version]
|
||||
|
||||
62
Dockerfile
62
Dockerfile
@@ -239,7 +239,8 @@ RUN apt-get update && \
|
||||
fi
|
||||
|
||||
# Create postgres user and set up directories
|
||||
RUN useradd -m -s /bin/bash postgres || true && \
|
||||
RUN groupadd -g 999 postgres || true && \
|
||||
useradd -m -s /bin/bash -u 999 -g 999 postgres || true && \
|
||||
mkdir -p /databasus-data/pgdata && \
|
||||
chown -R postgres:postgres /databasus-data/pgdata
|
||||
|
||||
@@ -257,6 +258,9 @@ COPY backend/migrations ./migrations
|
||||
# Copy UI files
|
||||
COPY --from=backend-build /app/ui/build ./ui/build
|
||||
|
||||
# Copy cloud static HTML template (injected into index.html at startup when IS_CLOUD=true)
|
||||
COPY frontend/cloud-root-content.html /app/cloud-root-content.html
|
||||
|
||||
# Copy agent binaries (both architectures) — served by the backend
|
||||
# at GET /api/v1/system/agent?arch=amd64|arm64
|
||||
COPY --from=agent-build /agent-binaries ./agent-binaries
|
||||
@@ -291,6 +295,23 @@ if [ -d "/postgresus-data" ] && [ "\$(ls -A /postgresus-data 2>/dev/null)" ]; th
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# ========= Adjust postgres user UID/GID =========
|
||||
PUID=\${PUID:-999}
|
||||
PGID=\${PGID:-999}
|
||||
|
||||
CURRENT_UID=\$(id -u postgres)
|
||||
CURRENT_GID=\$(id -g postgres)
|
||||
|
||||
if [ "\$CURRENT_GID" != "\$PGID" ]; then
|
||||
echo "Adjusting postgres group GID from \$CURRENT_GID to \$PGID..."
|
||||
groupmod -o -g "\$PGID" postgres
|
||||
fi
|
||||
|
||||
if [ "\$CURRENT_UID" != "\$PUID" ]; then
|
||||
echo "Adjusting postgres user UID from \$CURRENT_UID to \$PUID..."
|
||||
usermod -o -u "\$PUID" postgres
|
||||
fi
|
||||
|
||||
# PostgreSQL 17 binary paths
|
||||
PG_BIN="/usr/lib/postgresql/17/bin"
|
||||
|
||||
@@ -313,7 +334,9 @@ window.__RUNTIME_CONFIG__ = {
|
||||
GOOGLE_CLIENT_ID: '\${GOOGLE_CLIENT_ID:-}',
|
||||
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED',
|
||||
CLOUDFLARE_TURNSTILE_SITE_KEY: '\${CLOUDFLARE_TURNSTILE_SITE_KEY:-}',
|
||||
CONTAINER_ARCH: '\${CONTAINER_ARCH:-unknown}'
|
||||
CONTAINER_ARCH: '\${CONTAINER_ARCH:-unknown}',
|
||||
CLOUD_PRICE_PER_GB: '\${CLOUD_PRICE_PER_GB:-}',
|
||||
CLOUD_PADDLE_CLIENT_TOKEN: '\${CLOUD_PADDLE_CLIENT_TOKEN:-}'
|
||||
};
|
||||
JSEOF
|
||||
|
||||
@@ -326,6 +349,32 @@ if [ -n "\${ANALYTICS_SCRIPT:-}" ]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
# Inject Paddle script if client token is provided (only if not already injected)
|
||||
if [ -n "\${CLOUD_PADDLE_CLIENT_TOKEN:-}" ]; then
|
||||
if ! grep -q "cdn.paddle.com" /app/ui/build/index.html 2>/dev/null; then
|
||||
echo "Injecting Paddle script..."
|
||||
sed -i "s#</head># <script src=\"https://cdn.paddle.com/paddle/v2/paddle.js\"></script>\\
|
||||
</head>#" /app/ui/build/index.html
|
||||
fi
|
||||
fi
|
||||
|
||||
# Inject static HTML into root div for cloud mode (payment system requires visible legal links)
|
||||
if [ "\${IS_CLOUD:-false}" = "true" ]; then
|
||||
if ! grep -q "cloud-static-content" /app/ui/build/index.html 2>/dev/null; then
|
||||
echo "Injecting cloud static HTML content..."
|
||||
perl -i -pe '
|
||||
BEGIN {
|
||||
open my \$fh, "<", "/app/cloud-root-content.html" or die;
|
||||
local \$/;
|
||||
\$c = <\$fh>;
|
||||
close \$fh;
|
||||
\$c =~ s/\\n/ /g;
|
||||
}
|
||||
s/<div id="root"><\\/div>/<div id="root"><!-- cloud-static-content --><noscript>\$c<\\/noscript><\\/div>/
|
||||
' /app/ui/build/index.html
|
||||
fi
|
||||
fi
|
||||
|
||||
# Ensure proper ownership of data directory
|
||||
echo "Setting up data directory permissions..."
|
||||
mkdir -p /databasus-data/pgdata
|
||||
@@ -375,7 +424,12 @@ fi
|
||||
# Function to start PostgreSQL and wait for it to be ready
|
||||
start_postgres() {
|
||||
echo "Starting PostgreSQL..."
|
||||
gosu postgres \$PG_BIN/postgres -D /databasus-data/pgdata -p 5437 &
|
||||
# -k /tmp: create Unix socket and lock file in /tmp instead of /var/run/postgresql/.
|
||||
# On NAS systems (e.g. TrueNAS Scale), the ZFS-backed Docker overlay filesystem
|
||||
# ignores chown/chmod on directories from image layers, so PostgreSQL gets
|
||||
# "Permission denied" when creating .s.PGSQL.5437.lock in /var/run/postgresql/.
|
||||
# All internal connections use TCP (-h localhost), so the socket location does not matter.
|
||||
gosu postgres \$PG_BIN/postgres -D /databasus-data/pgdata -p 5437 -k /tmp &
|
||||
POSTGRES_PID=\$!
|
||||
|
||||
echo "Waiting for PostgreSQL to be ready..."
|
||||
@@ -439,7 +493,7 @@ fi
|
||||
echo "Setting up database and user..."
|
||||
gosu postgres \$PG_BIN/psql -p 5437 -h localhost -d postgres << 'SQL'
|
||||
|
||||
# We use stub password, because internal DB is not exposed outside container
|
||||
-- We use stub password, because internal DB is not exposed outside container
|
||||
ALTER USER postgres WITH PASSWORD 'Q1234567';
|
||||
SELECT 'CREATE DATABASE databasus OWNER postgres'
|
||||
WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = 'databasus')
|
||||
|
||||
22
NOTICE.md
Normal file
22
NOTICE.md
Normal file
@@ -0,0 +1,22 @@
|
||||
Copyright © 2025–2026 Rostislav Dugin and contributors.
|
||||
|
||||
“Databasus” is a trademark of Rostislav Dugin.
|
||||
|
||||
The source code in this repository is licensed under the Apache License, Version 2.0.
|
||||
That license applies to the code only and does not grant any right to use the
|
||||
Databasus name, logo, or branding, except for reasonable and customary referential
|
||||
use in describing the origin of the software and reproducing the content of this NOTICE.
|
||||
|
||||
Permitted referential use includes truthful use of the name “Databasus” to identify
|
||||
the original Databasus project in software catalogs, deployment templates, hosting
|
||||
panels, package indexes, compatibility pages, integrations, tutorials, reviews, and
|
||||
similar informational materials, including phrases such as “Databasus”,
|
||||
“Deploy Databasus”, “Databasus on Coolify”, and “Compatible with Databasus”.
|
||||
|
||||
You may not use “Databasus” as the name or primary branding of a competing product,
|
||||
service, fork, distribution, or hosted offering, or in any manner likely to cause
|
||||
confusion as to source, affiliation, sponsorship, or endorsement.
|
||||
|
||||
Nothing in this repository transfers, waives, limits, or estops any rights in the
|
||||
Databasus mark. All trademark rights are reserved except for the limited referential
|
||||
use stated above.
|
||||
34
README.md
34
README.md
@@ -1,8 +1,8 @@
|
||||
<div align="center">
|
||||
<img src="assets/logo.svg" alt="Databasus Logo" width="250"/>
|
||||
|
||||
<h3>Backup tool for PostgreSQL, MySQL and MongoDB</h3>
|
||||
<p>Databasus is a free, open source and self-hosted tool to backup databases (with focus on PostgreSQL). Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.)</p>
|
||||
<h3>PostgreSQL backup tool (with MySQL\MariaDB and MongoDB support)</h3>
|
||||
<p>Databasus is a free, open source and self-hosted tool to backup databases (with primary focus on PostgreSQL). Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.)</p>
|
||||
|
||||
<!-- Badges -->
|
||||
[](https://www.postgresql.org/)
|
||||
@@ -91,14 +91,16 @@ It is also important for Databasus that you are able to decrypt and restore back
|
||||
- **Dark & light themes**: Choose the look that suits your workflow
|
||||
- **Mobile adaptive**: Check your backups from anywhere on any device
|
||||
|
||||
### ☁️ **Works with self-hosted & cloud databases**
|
||||
### 🔌 **Connection types**
|
||||
|
||||
Databasus works seamlessly with both self-hosted PostgreSQL and cloud-managed databases:
|
||||
- **Remote** — Databasus connects directly to the database over the network (recommended in read-only mode). No agent or additional software required. Works with cloud-managed and self-hosted databases
|
||||
- **Agent** — A lightweight Databasus agent (written in Go) runs alongside the database. The agent streams backups directly to Databasus, so the database never needs to be exposed publicly. Supports host-installed databases and Docker containers
|
||||
|
||||
- **Cloud support**: AWS RDS, Google Cloud SQL, Azure Database for PostgreSQL
|
||||
- **Self-hosted**: Any PostgreSQL instance you manage yourself
|
||||
- **Why no PITR support?**: Cloud providers already offer native PITR, and external PITR backups cannot be restored to managed cloud databases — making them impractical for cloud-hosted PostgreSQL
|
||||
- **Practical granularity**: Hourly and daily backups are sufficient for 99% of projects without the operational complexity of WAL archiving
|
||||
### 📦 **Backup types**
|
||||
|
||||
- **Logical** — Native dump of the database in its engine-specific binary format. Compressed and streamed directly to storage with no intermediate files
|
||||
- **Physical** — File-level copy of the entire database cluster. Faster backup and restore for large datasets compared to logical dumps
|
||||
- **Incremental** — Physical base backup combined with continuous WAL segment archiving. Enables Point-in-time recovery (PITR) — restore to any second between backups. Designed for disaster recovery and near-zero data loss requirements
|
||||
|
||||
### 🐳 **Self-hosted & secure**
|
||||
|
||||
@@ -243,7 +245,7 @@ Replace `admin` with the actual email address of the user whose password you wan
|
||||
|
||||
### 💾 Backuping Databasus itself
|
||||
|
||||
After installation, it is also recommended to <a href="https://databasus.com/faq/#backup-databasus">backup your Databasus itself</a> or, at least, to copy secret key used for encryption (30 seconds is needed). So you are able to restore from your encrypted backups if you lose access to the server with Databasus or it is corrupted.
|
||||
After installation, it is also recommended to <a href="https://databasus.com/faq#backup-databasus">backup your Databasus itself</a> or, at least, to copy secret key used for encryption (30 seconds is needed). So you are able to restore from your encrypted backups if you lose access to the server with Databasus or it is corrupted.
|
||||
|
||||
---
|
||||
|
||||
@@ -257,7 +259,9 @@ Contributions are welcome! Read the <a href="https://databasus.com/contribute">c
|
||||
|
||||
Also you can join our large community of developers, DBAs and DevOps engineers on Telegram [@databasus_community](https://t.me/databasus_community).
|
||||
|
||||
## AI disclaimer
|
||||
## FAQ
|
||||
|
||||
### AI disclaimer
|
||||
|
||||
There have been questions about AI usage in project development in issues and discussions. As the project focuses on security, reliability and production usage, it's important to explain how AI is used in the development process.
|
||||
|
||||
@@ -293,3 +297,13 @@ Moreover, it's important to note that we do not differentiate between bad human
|
||||
Even if code is written manually by a human, it's not guaranteed to be merged. Vibe code is not allowed at all and all such PRs are rejected by default (see [contributing guide](https://databasus.com/contribute)).
|
||||
|
||||
We also draw attention to fast issue resolution and security [vulnerability reporting](https://github.com/databasus/databasus?tab=security-ov-file#readme).
|
||||
|
||||
### You have a cloud version — are you truly open source?
|
||||
|
||||
Yes. Every feature available in Databasus Cloud is equally available in the self-hosted version with no restrictions, no feature gates and no usage limits. The entire codebase is Apache 2.0 licensed and always will be.
|
||||
|
||||
Databasus is not "open core." We do not withhold features behind a paid tier and then call the limited remainder "open source," as projects like GitLab or Sentry do. We believe open source means the complete product is open, not just a marketing label on a stripped-down edition.
|
||||
|
||||
Databasus Cloud runs the exact same code as the self-hosted version. The only difference is that we take care of infrastructure, availability, backups, reservations, monitoring and updates for you — so you don't have to. If you are using cloud, you can always move your databases from cloud to self-hosted if you wish.
|
||||
|
||||
Revenue from Cloud funds full-time development of the project. Most large open-source projects rely on corporate backing or sponsorship to survive. We chose a different path: Databasus sustains itself so it can grow and improve independently, without being tied to any enterprise or sponsor.
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
ENV_MODE=development
|
||||
AGENT_DB_ID=your-database-id
|
||||
AGENT_TOKEN=your-agent-token
|
||||
|
||||
3
agent/.gitignore
vendored
3
agent/.gitignore
vendored
@@ -23,4 +23,5 @@ valkey-data/
|
||||
victoria-logs-data/
|
||||
databasus.json
|
||||
.test-tmp/
|
||||
databasus.log
|
||||
databasus.log
|
||||
wal-queue/
|
||||
@@ -1,8 +1,21 @@
|
||||
.PHONY: run build test lint e2e e2e-clean
|
||||
.PHONY: run build test lint e2e e2e-clean e2e-backup-restore e2e-backup-restore-clean
|
||||
|
||||
-include .env
|
||||
export
|
||||
|
||||
# Usage: make run ARGS="start --pg-host localhost"
|
||||
run:
|
||||
go run cmd/main.go $(ARGS)
|
||||
go run cmd/main.go start \
|
||||
--databasus-host http://localhost:4005 \
|
||||
--db-id $(AGENT_DB_ID) \
|
||||
--token $(AGENT_TOKEN) \
|
||||
--pg-host 127.0.0.1 \
|
||||
--pg-port 7433 \
|
||||
--pg-user devuser \
|
||||
--pg-password devpassword \
|
||||
--pg-type docker \
|
||||
--pg-docker-container-name dev-postgres \
|
||||
--pg-wal-dir ./wal-queue \
|
||||
--skip-update
|
||||
|
||||
build:
|
||||
CGO_ENABLED=0 go build -ldflags "-X main.Version=$(VERSION)" -o databasus-agent ./cmd/main.go
|
||||
@@ -14,6 +27,7 @@ lint:
|
||||
golangci-lint fmt ./cmd/... ./internal/... ./e2e/... && golangci-lint run ./cmd/... ./internal/... ./e2e/...
|
||||
|
||||
e2e:
|
||||
cd e2e && docker compose build --no-cache e2e-mock-server
|
||||
cd e2e && docker compose build
|
||||
cd e2e && docker compose run --rm e2e-agent-builder
|
||||
cd e2e && docker compose up -d e2e-postgres e2e-mock-server
|
||||
@@ -23,4 +37,5 @@ e2e:
|
||||
|
||||
e2e-clean:
|
||||
cd e2e && docker compose down -v --rmi local
|
||||
cd e2e && docker compose -f docker-compose.backup-restore.yml down -v --rmi local 2>/dev/null || true
|
||||
rm -rf e2e/artifacts
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
"databasus-agent/internal/features/restore"
|
||||
"databasus-agent/internal/features/start"
|
||||
"databasus-agent/internal/features/upgrade"
|
||||
"databasus-agent/internal/logger"
|
||||
@@ -115,10 +117,9 @@ func runStatus() {
|
||||
func runRestore(args []string) {
|
||||
fs := flag.NewFlagSet("restore", flag.ExitOnError)
|
||||
|
||||
targetDir := fs.String("target-dir", "", "Target pgdata directory")
|
||||
pgDataDir := fs.String("target-dir", "", "Target pgdata directory (required)")
|
||||
backupID := fs.String("backup-id", "", "Full backup UUID (optional)")
|
||||
targetTime := fs.String("target-time", "", "PITR target time in RFC3339 (optional)")
|
||||
isYes := fs.Bool("yes", false, "Skip confirmation prompt")
|
||||
isSkipUpdate := fs.Bool("skip-update", false, "Skip auto-update check")
|
||||
|
||||
cfg := &config.Config{}
|
||||
@@ -133,12 +134,29 @@ func runRestore(args []string) {
|
||||
isDev := checkIsDevelopment()
|
||||
runUpdateCheck(cfg.DatabasusHost, *isSkipUpdate, isDev, log)
|
||||
|
||||
log.Info("restore: stub — not yet implemented",
|
||||
"targetDir", *targetDir,
|
||||
"backupId", *backupID,
|
||||
"targetTime", *targetTime,
|
||||
"yes", *isYes,
|
||||
)
|
||||
if *pgDataDir == "" {
|
||||
fmt.Fprintln(os.Stderr, "Error: --target-dir is required")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if cfg.DatabasusHost == "" || cfg.Token == "" {
|
||||
fmt.Fprintln(os.Stderr, "Error: databasus-host and token must be configured")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if cfg.PgType != "host" && cfg.PgType != "docker" {
|
||||
fmt.Fprintf(os.Stderr, "Error: --pg-type must be 'host' or 'docker', got '%s'\n", cfg.PgType)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
apiClient := api.NewClient(cfg.DatabasusHost, cfg.Token, log)
|
||||
restorer := restore.NewRestorer(apiClient, log, *pgDataDir, *backupID, *targetTime, cfg.PgType)
|
||||
|
||||
ctx := context.Background()
|
||||
if err := restorer.Run(ctx); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func printUsage() {
|
||||
|
||||
58
agent/docker-compose.yml.example
Normal file
58
agent/docker-compose.yml.example
Normal file
@@ -0,0 +1,58 @@
|
||||
services:
|
||||
dev-postgres:
|
||||
image: postgres:17
|
||||
container_name: dev-postgres
|
||||
environment:
|
||||
POSTGRES_DB: devdb
|
||||
POSTGRES_USER: devuser
|
||||
POSTGRES_PASSWORD: devpassword
|
||||
ports:
|
||||
- "7433:5432"
|
||||
command:
|
||||
- bash
|
||||
- -c
|
||||
- |
|
||||
mkdir -p /wal-queue && chown postgres:postgres /wal-queue
|
||||
exec docker-entrypoint.sh postgres \
|
||||
-c wal_level=replica \
|
||||
-c max_wal_senders=3 \
|
||||
-c archive_mode=on \
|
||||
-c "archive_command=cp %p /wal-queue/%f"
|
||||
volumes:
|
||||
- ./wal-queue:/wal-queue
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U devuser -d devdb"]
|
||||
interval: 2s
|
||||
timeout: 5s
|
||||
retries: 30
|
||||
|
||||
db-writer:
|
||||
image: postgres:17
|
||||
container_name: dev-db-writer
|
||||
depends_on:
|
||||
dev-postgres:
|
||||
condition: service_healthy
|
||||
environment:
|
||||
PGHOST: dev-postgres
|
||||
PGPORT: "5432"
|
||||
PGUSER: devuser
|
||||
PGPASSWORD: devpassword
|
||||
PGDATABASE: devdb
|
||||
command:
|
||||
- bash
|
||||
- -c
|
||||
- |
|
||||
echo "Waiting for postgres..."
|
||||
until pg_isready -h dev-postgres -U devuser -d devdb; do sleep 1; done
|
||||
|
||||
psql -c "DROP TABLE IF EXISTS wal_generator;"
|
||||
psql -c "CREATE TABLE wal_generator (id SERIAL PRIMARY KEY, data TEXT NOT NULL);"
|
||||
echo "Starting WAL generation loop..."
|
||||
while true; do
|
||||
echo "Inserting ~50MB of data..."
|
||||
psql -c "INSERT INTO wal_generator (data) SELECT repeat(md5(random()::text), 640) FROM generate_series(1, 2500);"
|
||||
echo "Deleting data..."
|
||||
psql -c "DELETE FROM wal_generator;"
|
||||
echo "Cycle complete, sleeping 5s..."
|
||||
sleep 5
|
||||
done
|
||||
1
agent/e2e/.gitignore
vendored
1
agent/e2e/.gitignore
vendored
@@ -1 +1,2 @@
|
||||
artifacts/
|
||||
pgdata/
|
||||
|
||||
@@ -1,8 +1,22 @@
|
||||
# Runs pg_basebackup-via-docker-exec test (test 5) which tests
|
||||
# that the agent can connect to Postgres inside Docker container
|
||||
FROM docker:27-cli
|
||||
# Runs backup-restore via docker exec test (test 6). Needs both Docker
|
||||
# CLI (for pg_basebackup via docker exec) and PostgreSQL server (for
|
||||
# restore verification).
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
RUN apk add --no-cache bash curl
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ca-certificates curl gnupg2 locales postgresql-common && \
|
||||
sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen && \
|
||||
locale-gen && \
|
||||
/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
postgresql-17 && \
|
||||
install -m 0755 -d /etc/apt/keyrings && \
|
||||
curl -fsSL https://download.docker.com/linux/debian/gpg -o /etc/apt/keyrings/docker.asc && \
|
||||
echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/debian bookworm stable" > /etc/apt/sources.list.d/docker.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends docker-ce-cli && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /tmp
|
||||
ENTRYPOINT []
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Runs upgrade and host-mode pg_basebackup tests (tests 1-4). Needs
|
||||
# Postgres client tools to be installed inside the system
|
||||
# Runs upgrade and host-mode backup-restore tests (tests 1-5). Needs
|
||||
# full PostgreSQL server for backup-restore lifecycle tests.
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
RUN apt-get update && \
|
||||
@@ -7,7 +7,7 @@ RUN apt-get update && \
|
||||
ca-certificates curl gnupg2 postgresql-common && \
|
||||
/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
postgresql-client-17 && \
|
||||
postgresql-17 && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /tmp
|
||||
|
||||
16
agent/e2e/Dockerfile.backup-restore-runner
Normal file
16
agent/e2e/Dockerfile.backup-restore-runner
Normal file
@@ -0,0 +1,16 @@
|
||||
# Runs backup-restore lifecycle tests with a specific PostgreSQL version.
|
||||
# Used for PG version matrix testing (15, 16, 17, 18).
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
ARG PG_VERSION=17
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ca-certificates curl gnupg2 postgresql-common && \
|
||||
/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
postgresql-${PG_VERSION} && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /tmp
|
||||
ENTRYPOINT []
|
||||
33
agent/e2e/docker-compose.backup-restore.yml
Normal file
33
agent/e2e/docker-compose.backup-restore.yml
Normal file
@@ -0,0 +1,33 @@
|
||||
services:
|
||||
e2e-br-mock-server:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.mock-server
|
||||
volumes:
|
||||
- backup-storage:/backup-storage
|
||||
container_name: e2e-br-mock-server
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "-q", "--spider", "http://localhost:4050/health"]
|
||||
interval: 2s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
e2e-br-runner:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.backup-restore-runner
|
||||
args:
|
||||
PG_VERSION: ${PG_VERSION:-17}
|
||||
volumes:
|
||||
- ./artifacts:/opt/agent/artifacts:ro
|
||||
- ./scripts:/opt/agent/scripts:ro
|
||||
depends_on:
|
||||
e2e-br-mock-server:
|
||||
condition: service_healthy
|
||||
container_name: e2e-br-runner
|
||||
command: ["bash", "/opt/agent/scripts/test-pg-host-path.sh"]
|
||||
environment:
|
||||
MOCK_SERVER_OVERRIDE: "http://e2e-br-mock-server:4050"
|
||||
|
||||
volumes:
|
||||
backup-storage:
|
||||
@@ -14,7 +14,19 @@ services:
|
||||
POSTGRES_USER: testuser
|
||||
POSTGRES_PASSWORD: testpassword
|
||||
container_name: e2e-agent-postgres
|
||||
command: postgres -c wal_level=replica -c max_wal_senders=3
|
||||
command:
|
||||
- bash
|
||||
- -c
|
||||
- |
|
||||
mkdir -p /wal-queue && chown postgres:postgres /wal-queue
|
||||
exec docker-entrypoint.sh postgres \
|
||||
-c wal_level=replica \
|
||||
-c max_wal_senders=3 \
|
||||
-c archive_mode=on \
|
||||
-c "archive_command=cp %p /wal-queue/%f"
|
||||
volumes:
|
||||
- ./pgdata:/var/lib/postgresql/data
|
||||
- wal-queue:/wal-queue
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U testuser -d testdb"]
|
||||
interval: 2s
|
||||
@@ -27,6 +39,7 @@ services:
|
||||
dockerfile: Dockerfile.mock-server
|
||||
volumes:
|
||||
- ./artifacts:/artifacts:ro
|
||||
- backup-storage:/backup-storage
|
||||
container_name: e2e-mock-server
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "-q", "--spider", "http://localhost:4050/health"]
|
||||
@@ -57,8 +70,15 @@ services:
|
||||
- ./artifacts:/opt/agent/artifacts:ro
|
||||
- ./scripts:/opt/agent/scripts:ro
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- wal-queue:/wal-queue
|
||||
depends_on:
|
||||
e2e-postgres:
|
||||
condition: service_healthy
|
||||
e2e-mock-server:
|
||||
condition: service_healthy
|
||||
container_name: e2e-agent-docker
|
||||
command: ["bash", "/opt/agent/scripts/run-all.sh", "docker"]
|
||||
|
||||
volumes:
|
||||
wal-queue:
|
||||
backup-storage:
|
||||
|
||||
@@ -1,17 +1,39 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const backupStorageDir = "/backup-storage"
|
||||
|
||||
type walSegment struct {
|
||||
BackupID string
|
||||
SegmentName string
|
||||
FilePath string
|
||||
SizeBytes int64
|
||||
}
|
||||
|
||||
type server struct {
|
||||
mu sync.RWMutex
|
||||
version string
|
||||
binaryPath string
|
||||
|
||||
backupID string
|
||||
backupFilePath string
|
||||
startSegment string
|
||||
stopSegment string
|
||||
isFinalized bool
|
||||
walSegments []walSegment
|
||||
backupCreatedAt time.Time
|
||||
}
|
||||
|
||||
func main() {
|
||||
@@ -19,12 +41,31 @@ func main() {
|
||||
binaryPath := "/artifacts/agent-v2"
|
||||
port := "4050"
|
||||
|
||||
_ = os.MkdirAll(backupStorageDir, 0o755)
|
||||
|
||||
s := &server{version: version, binaryPath: binaryPath}
|
||||
|
||||
// System endpoints
|
||||
http.HandleFunc("/api/v1/system/version", s.handleVersion)
|
||||
http.HandleFunc("/api/v1/system/agent", s.handleAgentDownload)
|
||||
|
||||
// Backup endpoints
|
||||
http.HandleFunc("/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup", s.handleChainValidity)
|
||||
http.HandleFunc("/api/v1/backups/postgres/wal/next-full-backup-time", s.handleNextBackupTime)
|
||||
http.HandleFunc("/api/v1/backups/postgres/wal/upload/full-start", s.handleFullStart)
|
||||
http.HandleFunc("/api/v1/backups/postgres/wal/upload/full-complete", s.handleFullComplete)
|
||||
http.HandleFunc("/api/v1/backups/postgres/wal/upload/wal", s.handleWalUpload)
|
||||
http.HandleFunc("/api/v1/backups/postgres/wal/error", s.handleError)
|
||||
|
||||
// Restore endpoints
|
||||
http.HandleFunc("/api/v1/backups/postgres/wal/restore/plan", s.handleRestorePlan)
|
||||
http.HandleFunc("/api/v1/backups/postgres/wal/restore/download", s.handleRestoreDownload)
|
||||
|
||||
// Mock control endpoints
|
||||
http.HandleFunc("/mock/set-version", s.handleSetVersion)
|
||||
http.HandleFunc("/mock/set-binary-path", s.handleSetBinaryPath)
|
||||
http.HandleFunc("/mock/backup-status", s.handleBackupStatus)
|
||||
http.HandleFunc("/mock/reset", s.handleReset)
|
||||
http.HandleFunc("/health", s.handleHealth)
|
||||
|
||||
addr := ":" + port
|
||||
@@ -35,7 +76,9 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) handleVersion(w http.ResponseWriter, r *http.Request) {
|
||||
// --- System handlers ---
|
||||
|
||||
func (s *server) handleVersion(w http.ResponseWriter, _ *http.Request) {
|
||||
s.mu.RLock()
|
||||
v := s.version
|
||||
s.mu.RUnlock()
|
||||
@@ -56,6 +99,263 @@ func (s *server) handleAgentDownload(w http.ResponseWriter, r *http.Request) {
|
||||
http.ServeFile(w, r, path)
|
||||
}
|
||||
|
||||
// --- Backup handlers ---
|
||||
|
||||
func (s *server) handleChainValidity(w http.ResponseWriter, _ *http.Request) {
|
||||
s.mu.RLock()
|
||||
isFinalized := s.isFinalized
|
||||
s.mu.RUnlock()
|
||||
|
||||
log.Printf("GET chain-validity -> isFinalized=%v", isFinalized)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if isFinalized {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"isValid": true,
|
||||
})
|
||||
} else {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"isValid": false,
|
||||
"error": "no full backup found",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) handleNextBackupTime(w http.ResponseWriter, _ *http.Request) {
|
||||
log.Printf("GET next-full-backup-time")
|
||||
|
||||
nextTime := time.Now().UTC().Add(1 * time.Hour)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"nextFullBackupTime": nextTime.Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *server) handleFullStart(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
backupID := generateID()
|
||||
filePath := filepath.Join(backupStorageDir, backupID+".zst")
|
||||
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
log.Printf("ERROR creating backup file: %v", err)
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
bytesWritten, err := io.Copy(file, r.Body)
|
||||
_ = file.Close()
|
||||
|
||||
if err != nil {
|
||||
log.Printf("ERROR writing backup data: %v", err)
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.backupID = backupID
|
||||
s.backupFilePath = filePath
|
||||
s.backupCreatedAt = time.Now().UTC()
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Printf("POST full-start -> backupID=%s, size=%d bytes", backupID, bytesWritten)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"backupId": backupID})
|
||||
}
|
||||
|
||||
func (s *server) handleFullComplete(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var body struct {
|
||||
BackupID string `json:"backupId"`
|
||||
StartSegment string `json:"startSegment"`
|
||||
StopSegment string `json:"stopSegment"`
|
||||
Error *string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if body.Error != nil {
|
||||
log.Printf("POST full-complete -> backupID=%s ERROR: %s", body.BackupID, *body.Error)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.startSegment = body.StartSegment
|
||||
s.stopSegment = body.StopSegment
|
||||
s.isFinalized = true
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Printf(
|
||||
"POST full-complete -> backupID=%s, start=%s, stop=%s",
|
||||
body.BackupID,
|
||||
body.StartSegment,
|
||||
body.StopSegment,
|
||||
)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (s *server) handleWalUpload(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
segmentName := r.Header.Get("X-Wal-Segment-Name")
|
||||
if segmentName == "" {
|
||||
http.Error(w, "missing X-Wal-Segment-Name header", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
walBackupID := generateID()
|
||||
filePath := filepath.Join(backupStorageDir, walBackupID+".zst")
|
||||
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
log.Printf("ERROR creating WAL file: %v", err)
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
bytesWritten, err := io.Copy(file, r.Body)
|
||||
_ = file.Close()
|
||||
|
||||
if err != nil {
|
||||
log.Printf("ERROR writing WAL data: %v", err)
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.walSegments = append(s.walSegments, walSegment{
|
||||
BackupID: walBackupID,
|
||||
SegmentName: segmentName,
|
||||
FilePath: filePath,
|
||||
SizeBytes: bytesWritten,
|
||||
})
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Printf("POST wal-upload -> segment=%s, walBackupID=%s, size=%d", segmentName, walBackupID, bytesWritten)
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (s *server) handleError(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
log.Printf("POST error -> failed to decode: %v", err)
|
||||
} else {
|
||||
log.Printf("POST error -> %s", body.Error)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// --- Restore handlers ---
|
||||
|
||||
func (s *server) handleRestorePlan(w http.ResponseWriter, _ *http.Request) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if !s.isFinalized {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "no_backups",
|
||||
"message": "No full backups available",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
backupFileInfo, err := os.Stat(s.backupFilePath)
|
||||
if err != nil {
|
||||
log.Printf("ERROR stat backup file: %v", err)
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
backupSizeBytes := backupFileInfo.Size()
|
||||
totalSizeBytes := backupSizeBytes
|
||||
|
||||
walSegmentsJSON := make([]map[string]any, 0, len(s.walSegments))
|
||||
|
||||
latestSegment := ""
|
||||
|
||||
for _, segment := range s.walSegments {
|
||||
totalSizeBytes += segment.SizeBytes
|
||||
latestSegment = segment.SegmentName
|
||||
|
||||
walSegmentsJSON = append(walSegmentsJSON, map[string]any{
|
||||
"backupId": segment.BackupID,
|
||||
"segmentName": segment.SegmentName,
|
||||
"sizeBytes": segment.SizeBytes,
|
||||
})
|
||||
}
|
||||
|
||||
response := map[string]any{
|
||||
"fullBackup": map[string]any{
|
||||
"id": s.backupID,
|
||||
"fullBackupWalStartSegment": s.startSegment,
|
||||
"fullBackupWalStopSegment": s.stopSegment,
|
||||
"pgVersion": "17",
|
||||
"createdAt": s.backupCreatedAt.Format(time.RFC3339),
|
||||
"sizeBytes": backupSizeBytes,
|
||||
},
|
||||
"walSegments": walSegmentsJSON,
|
||||
"totalSizeBytes": totalSizeBytes,
|
||||
"latestAvailableSegment": latestSegment,
|
||||
}
|
||||
|
||||
log.Printf("GET restore-plan -> backupID=%s, walSegments=%d", s.backupID, len(s.walSegments))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
func (s *server) handleRestoreDownload(w http.ResponseWriter, r *http.Request) {
|
||||
requestedBackupID := r.URL.Query().Get("backupId")
|
||||
if requestedBackupID == "" {
|
||||
http.Error(w, "missing backupId query param", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
filePath := s.findBackupFile(requestedBackupID)
|
||||
if filePath == "" {
|
||||
log.Printf("GET restore-download -> backupId=%s NOT FOUND", requestedBackupID)
|
||||
http.Error(w, "backup not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("GET restore-download -> backupId=%s, file=%s", requestedBackupID, filePath)
|
||||
|
||||
http.ServeFile(w, r, filePath)
|
||||
}
|
||||
|
||||
// --- Mock control handlers ---
|
||||
|
||||
func (s *server) handleSetVersion(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
||||
@@ -65,6 +365,7 @@ func (s *server) handleSetVersion(w http.ResponseWriter, r *http.Request) {
|
||||
var body struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
@@ -88,6 +389,7 @@ func (s *server) handleSetBinaryPath(w http.ResponseWriter, r *http.Request) {
|
||||
var body struct {
|
||||
BinaryPath string `json:"binaryPath"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
@@ -102,7 +404,74 @@ func (s *server) handleSetBinaryPath(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = fmt.Fprintf(w, "binary path set to %s", body.BinaryPath)
|
||||
}
|
||||
|
||||
func (s *server) handleBackupStatus(w http.ResponseWriter, _ *http.Request) {
|
||||
s.mu.RLock()
|
||||
isFinalized := s.isFinalized
|
||||
walSegmentCount := len(s.walSegments)
|
||||
s.mu.RUnlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"isFinalized": isFinalized,
|
||||
"walSegmentCount": walSegmentCount,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *server) handleReset(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.backupID = ""
|
||||
s.backupFilePath = ""
|
||||
s.startSegment = ""
|
||||
s.stopSegment = ""
|
||||
s.isFinalized = false
|
||||
s.walSegments = nil
|
||||
s.backupCreatedAt = time.Time{}
|
||||
s.mu.Unlock()
|
||||
|
||||
// Clean stored files
|
||||
entries, _ := os.ReadDir(backupStorageDir)
|
||||
for _, entry := range entries {
|
||||
_ = os.Remove(filepath.Join(backupStorageDir, entry.Name()))
|
||||
}
|
||||
|
||||
log.Printf("POST /mock/reset -> state cleared")
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}
|
||||
|
||||
func (s *server) handleHealth(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}
|
||||
|
||||
// --- Private helpers ---
|
||||
|
||||
func generateID() string {
|
||||
b := make([]byte, 16)
|
||||
_, _ = rand.Read(b)
|
||||
|
||||
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
|
||||
}
|
||||
|
||||
func (s *server) findBackupFile(backupID string) string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.backupID == backupID {
|
||||
return s.backupFilePath
|
||||
}
|
||||
|
||||
for _, segment := range s.walSegments {
|
||||
if segment.BackupID == backupID {
|
||||
return segment.FilePath
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
357
agent/e2e/scripts/backup-restore-helpers.sh
Normal file
357
agent/e2e/scripts/backup-restore-helpers.sh
Normal file
@@ -0,0 +1,357 @@
|
||||
#!/bin/bash
|
||||
# Shared helper functions for backup-restore E2E tests.
|
||||
# Source this file from test scripts: source "$(dirname "$0")/backup-restore-helpers.sh"
|
||||
|
||||
AGENT="/tmp/test-agent"
|
||||
AGENT_PID=""
|
||||
|
||||
cleanup_agent() {
|
||||
if [ -n "$AGENT_PID" ]; then
|
||||
kill "$AGENT_PID" 2>/dev/null || true
|
||||
wait "$AGENT_PID" 2>/dev/null || true
|
||||
AGENT_PID=""
|
||||
fi
|
||||
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
}
|
||||
|
||||
setup_agent() {
|
||||
local artifacts="${1:-/opt/agent/artifacts}"
|
||||
|
||||
cleanup_agent
|
||||
cp "$artifacts/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
}
|
||||
|
||||
init_pg_local() {
|
||||
local pgdata="$1"
|
||||
local port="$2"
|
||||
local wal_queue="$3"
|
||||
local pg_bin_dir="$4"
|
||||
|
||||
# Stop any leftover PG from previous test runs
|
||||
su postgres -c "$pg_bin_dir/pg_ctl -D $pgdata stop -m immediate" 2>/dev/null || true
|
||||
su postgres -c "$pg_bin_dir/pg_ctl -D /tmp/restore-pgdata stop -m immediate" 2>/dev/null || true
|
||||
|
||||
mkdir -p "$wal_queue"
|
||||
chown postgres:postgres "$wal_queue"
|
||||
rm -rf "$pgdata"
|
||||
|
||||
su postgres -c "$pg_bin_dir/initdb -D $pgdata" > /dev/null
|
||||
|
||||
cat >> "$pgdata/postgresql.conf" <<PGCONF
|
||||
wal_level = replica
|
||||
archive_mode = on
|
||||
archive_command = 'cp %p $wal_queue/%f'
|
||||
max_wal_senders = 3
|
||||
listen_addresses = 'localhost'
|
||||
port = $port
|
||||
checkpoint_timeout = 30s
|
||||
PGCONF
|
||||
|
||||
echo "local all all trust" > "$pgdata/pg_hba.conf"
|
||||
echo "host all all 127.0.0.1/32 trust" >> "$pgdata/pg_hba.conf"
|
||||
echo "host all all ::1/128 trust" >> "$pgdata/pg_hba.conf"
|
||||
echo "local replication all trust" >> "$pgdata/pg_hba.conf"
|
||||
echo "host replication all 127.0.0.1/32 trust" >> "$pgdata/pg_hba.conf"
|
||||
echo "host replication all ::1/128 trust" >> "$pgdata/pg_hba.conf"
|
||||
|
||||
su postgres -c "$pg_bin_dir/pg_ctl -D $pgdata -l /tmp/pg.log start -w"
|
||||
|
||||
su postgres -c "$pg_bin_dir/psql -p $port -c \"CREATE USER testuser WITH SUPERUSER REPLICATION;\"" > /dev/null 2>&1 || true
|
||||
su postgres -c "$pg_bin_dir/psql -p $port -c \"CREATE DATABASE testdb OWNER testuser;\"" > /dev/null 2>&1 || true
|
||||
|
||||
echo "PostgreSQL initialized and started on port $port"
|
||||
}
|
||||
|
||||
insert_test_data() {
|
||||
local port="$1"
|
||||
local pg_bin_dir="$2"
|
||||
|
||||
su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb" <<SQL
|
||||
CREATE TABLE e2e_test_data (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
value INT NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
|
||||
INSERT INTO e2e_test_data (name, value) VALUES
|
||||
('row1', 100),
|
||||
('row2', 200),
|
||||
('row3', 300);
|
||||
SQL
|
||||
|
||||
echo "Test data inserted (3 rows)"
|
||||
}
|
||||
|
||||
force_checkpoint() {
|
||||
local port="$1"
|
||||
local pg_bin_dir="$2"
|
||||
|
||||
su postgres -c "$pg_bin_dir/psql -p $port -c 'CHECKPOINT;'" > /dev/null
|
||||
echo "Checkpoint forced"
|
||||
}
|
||||
|
||||
run_agent_backup() {
|
||||
local mock_server="$1"
|
||||
local pg_host="$2"
|
||||
local pg_port="$3"
|
||||
local wal_queue="$4"
|
||||
local pg_type="$5"
|
||||
local pg_host_bin_dir="${6:-}"
|
||||
local pg_docker_container="${7:-}"
|
||||
|
||||
# Reset mock server state and set version to match agent (prevents background upgrade loop)
|
||||
curl -sf -X POST "$mock_server/mock/reset" > /dev/null
|
||||
curl -sf -X POST "$mock_server/mock/set-version" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"version":"v1.0.0"}' > /dev/null
|
||||
|
||||
# Build JSON config
|
||||
cd /tmp
|
||||
|
||||
local extra_fields=""
|
||||
if [ -n "$pg_host_bin_dir" ]; then
|
||||
extra_fields="$extra_fields\"pgHostBinDir\": \"$pg_host_bin_dir\","
|
||||
fi
|
||||
if [ -n "$pg_docker_container" ]; then
|
||||
extra_fields="$extra_fields\"pgDockerContainerName\": \"$pg_docker_container\","
|
||||
fi
|
||||
|
||||
cat > databasus.json <<AGENTCONF
|
||||
{
|
||||
"databasusHost": "$mock_server",
|
||||
"dbId": "test-db-id",
|
||||
"token": "test-token",
|
||||
"pgHost": "$pg_host",
|
||||
"pgPort": $pg_port,
|
||||
"pgUser": "testuser",
|
||||
"pgPassword": "",
|
||||
${extra_fields}
|
||||
"pgType": "$pg_type",
|
||||
"pgWalDir": "$wal_queue",
|
||||
"deleteWalAfterUpload": true
|
||||
}
|
||||
AGENTCONF
|
||||
|
||||
# Run agent daemon in background
|
||||
"$AGENT" _run > /tmp/agent-output.log 2>&1 &
|
||||
AGENT_PID=$!
|
||||
|
||||
echo "Agent started with PID $AGENT_PID"
|
||||
}
|
||||
|
||||
generate_wal_background() {
|
||||
local port="$1"
|
||||
local pg_bin_dir="$2"
|
||||
|
||||
while true; do
|
||||
su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb -c \"
|
||||
INSERT INTO e2e_test_data (name, value)
|
||||
SELECT 'bulk_' || g, g FROM generate_series(1, 1000) g;
|
||||
SELECT pg_switch_wal();
|
||||
\"" > /dev/null 2>&1 || break
|
||||
sleep 2
|
||||
done
|
||||
}
|
||||
|
||||
generate_wal_docker_background() {
|
||||
local container="$1"
|
||||
|
||||
while true; do
|
||||
docker exec "$container" psql -U testuser -d testdb -c "
|
||||
INSERT INTO e2e_test_data (name, value)
|
||||
SELECT 'bulk_' || g, g FROM generate_series(1, 1000) g;
|
||||
SELECT pg_switch_wal();
|
||||
" > /dev/null 2>&1 || break
|
||||
sleep 2
|
||||
done
|
||||
}
|
||||
|
||||
wait_for_backup_complete() {
|
||||
local mock_server="$1"
|
||||
local timeout="${2:-120}"
|
||||
|
||||
echo "Waiting for backup to complete (timeout: ${timeout}s)..."
|
||||
|
||||
for i in $(seq 1 "$timeout"); do
|
||||
STATUS=$(curl -sf "$mock_server/mock/backup-status" 2>/dev/null || echo '{}')
|
||||
IS_FINALIZED=$(echo "$STATUS" | grep -o '"isFinalized":true' || true)
|
||||
WAL_COUNT=$(echo "$STATUS" | grep -o '"walSegmentCount":[0-9]*' | grep -o '[0-9]*$' || echo "0")
|
||||
|
||||
if [ -n "$IS_FINALIZED" ] && [ "$WAL_COUNT" -gt 0 ]; then
|
||||
echo "Backup complete: finalized with $WAL_COUNT WAL segments"
|
||||
return 0
|
||||
fi
|
||||
|
||||
sleep 1
|
||||
done
|
||||
|
||||
echo "FAIL: Backup did not complete within ${timeout} seconds"
|
||||
echo "Last status: $STATUS"
|
||||
echo "Agent output:"
|
||||
cat /tmp/agent-output.log 2>/dev/null || true
|
||||
return 1
|
||||
}
|
||||
|
||||
stop_agent() {
|
||||
if [ -n "$AGENT_PID" ]; then
|
||||
kill "$AGENT_PID" 2>/dev/null || true
|
||||
wait "$AGENT_PID" 2>/dev/null || true
|
||||
AGENT_PID=""
|
||||
fi
|
||||
|
||||
echo "Agent stopped"
|
||||
}
|
||||
|
||||
stop_pg() {
|
||||
local pgdata="$1"
|
||||
local pg_bin_dir="$2"
|
||||
|
||||
su postgres -c "$pg_bin_dir/pg_ctl -D $pgdata stop -m fast" 2>/dev/null || true
|
||||
|
||||
echo "PostgreSQL stopped"
|
||||
}
|
||||
|
||||
run_agent_restore() {
|
||||
local mock_server="$1"
|
||||
local restore_dir="$2"
|
||||
|
||||
rm -rf "$restore_dir"
|
||||
mkdir -p "$restore_dir"
|
||||
chown postgres:postgres "$restore_dir"
|
||||
|
||||
cd /tmp
|
||||
|
||||
"$AGENT" restore \
|
||||
--skip-update \
|
||||
--databasus-host "$mock_server" \
|
||||
--token test-token \
|
||||
--target-dir "$restore_dir"
|
||||
|
||||
echo "Agent restore completed"
|
||||
}
|
||||
|
||||
start_restored_pg() {
|
||||
local restore_dir="$1"
|
||||
local port="$2"
|
||||
local pg_bin_dir="$3"
|
||||
|
||||
# Ensure port is set in restored config
|
||||
if ! grep -q "^port" "$restore_dir/postgresql.conf" 2>/dev/null; then
|
||||
echo "port = $port" >> "$restore_dir/postgresql.conf"
|
||||
fi
|
||||
|
||||
# Ensure listen_addresses is set
|
||||
if ! grep -q "^listen_addresses" "$restore_dir/postgresql.conf" 2>/dev/null; then
|
||||
echo "listen_addresses = 'localhost'" >> "$restore_dir/postgresql.conf"
|
||||
fi
|
||||
|
||||
chown -R postgres:postgres "$restore_dir"
|
||||
chmod 700 "$restore_dir"
|
||||
|
||||
if ! su postgres -c "$pg_bin_dir/pg_ctl -D $restore_dir -l /tmp/pg-restore.log start -w"; then
|
||||
echo "FAIL: PostgreSQL failed to start on restored data"
|
||||
echo "--- pg-restore.log ---"
|
||||
cat /tmp/pg-restore.log 2>/dev/null || echo "(no log file)"
|
||||
echo "--- postgresql.auto.conf ---"
|
||||
cat "$restore_dir/postgresql.auto.conf" 2>/dev/null || echo "(no file)"
|
||||
echo "--- pg_wal/ listing ---"
|
||||
ls -la "$restore_dir/pg_wal/" 2>/dev/null || echo "(no pg_wal dir)"
|
||||
echo "--- databasus-wal-restore/ listing ---"
|
||||
ls -la "$restore_dir/databasus-wal-restore/" 2>/dev/null || echo "(no dir)"
|
||||
echo "--- end diagnostics ---"
|
||||
return 1
|
||||
fi
|
||||
|
||||
echo "PostgreSQL started on restored data"
|
||||
}
|
||||
|
||||
wait_for_recovery_complete() {
|
||||
local port="$1"
|
||||
local pg_bin_dir="$2"
|
||||
local timeout="${3:-60}"
|
||||
|
||||
echo "Waiting for recovery to complete (timeout: ${timeout}s)..."
|
||||
|
||||
for i in $(seq 1 "$timeout"); do
|
||||
IS_READY=$(su postgres -c "$pg_bin_dir/pg_isready -p $port" 2>&1 || true)
|
||||
|
||||
if echo "$IS_READY" | grep -q "accepting connections"; then
|
||||
IN_RECOVERY=$(su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb -t -c 'SELECT pg_is_in_recovery();'" 2>/dev/null | tr -d ' \n' || echo "t")
|
||||
|
||||
if [ "$IN_RECOVERY" = "f" ]; then
|
||||
echo "PostgreSQL recovered and promoted to primary"
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
|
||||
sleep 1
|
||||
done
|
||||
|
||||
echo "FAIL: PostgreSQL did not recover within ${timeout} seconds"
|
||||
echo "Recovery log:"
|
||||
cat /tmp/pg-restore.log 2>/dev/null || true
|
||||
return 1
|
||||
}
|
||||
|
||||
verify_restored_data() {
|
||||
local port="$1"
|
||||
local pg_bin_dir="$2"
|
||||
|
||||
ROW_COUNT=$(su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb -t -c 'SELECT COUNT(*) FROM e2e_test_data;'" | tr -d ' \n')
|
||||
|
||||
if [ "$ROW_COUNT" -lt 3 ]; then
|
||||
echo "FAIL: Expected at least 3 rows, got $ROW_COUNT"
|
||||
su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb -c 'SELECT * FROM e2e_test_data;'"
|
||||
return 1
|
||||
fi
|
||||
|
||||
RESULT=$(su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb -t -c \"SELECT value FROM e2e_test_data WHERE name='row1';\"" | tr -d ' \n')
|
||||
|
||||
if [ "$RESULT" != "100" ]; then
|
||||
echo "FAIL: Expected row1 value=100, got $RESULT"
|
||||
return 1
|
||||
fi
|
||||
|
||||
RESULT2=$(su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb -t -c \"SELECT value FROM e2e_test_data WHERE name='row3';\"" | tr -d ' \n')
|
||||
|
||||
if [ "$RESULT2" != "300" ]; then
|
||||
echo "FAIL: Expected row3 value=300, got $RESULT2"
|
||||
return 1
|
||||
fi
|
||||
|
||||
echo "PASS: Found $ROW_COUNT rows, data integrity verified"
|
||||
return 0
|
||||
}
|
||||
|
||||
find_pg_bin_dir() {
|
||||
# Find the PG bin dir from the installed version
|
||||
local pg_config_path
|
||||
pg_config_path=$(which pg_config 2>/dev/null || true)
|
||||
|
||||
if [ -n "$pg_config_path" ]; then
|
||||
pg_config --bindir
|
||||
return
|
||||
fi
|
||||
|
||||
# Fallback: search common locations
|
||||
for version in 18 17 16 15; do
|
||||
if [ -d "/usr/lib/postgresql/$version/bin" ]; then
|
||||
echo "/usr/lib/postgresql/$version/bin"
|
||||
return
|
||||
fi
|
||||
done
|
||||
|
||||
echo "ERROR: Cannot find PostgreSQL bin directory" >&2
|
||||
return 1
|
||||
}
|
||||
@@ -5,6 +5,7 @@ MODE="${1:-host}"
|
||||
SCRIPT_DIR="$(dirname "$0")"
|
||||
PASSED=0
|
||||
FAILED=0
|
||||
FAILED_NAMES=""
|
||||
|
||||
run_test() {
|
||||
local name="$1"
|
||||
@@ -21,6 +22,7 @@ run_test() {
|
||||
else
|
||||
echo " FAILED: $name"
|
||||
FAILED=$((FAILED + 1))
|
||||
FAILED_NAMES="${FAILED_NAMES}\n - ${name}"
|
||||
fi
|
||||
}
|
||||
|
||||
@@ -28,11 +30,11 @@ if [ "$MODE" = "host" ]; then
|
||||
run_test "Test 1: Upgrade success (v1 -> v2)" "$SCRIPT_DIR/test-upgrade-success.sh"
|
||||
run_test "Test 2: Upgrade skip (version matches)" "$SCRIPT_DIR/test-upgrade-skip.sh"
|
||||
run_test "Test 3: Background upgrade (v1 -> v2 while running)" "$SCRIPT_DIR/test-upgrade-background.sh"
|
||||
run_test "Test 4: pg_basebackup in PATH" "$SCRIPT_DIR/test-pg-host-path.sh"
|
||||
run_test "Test 5: pg_basebackup via bindir" "$SCRIPT_DIR/test-pg-host-bindir.sh"
|
||||
run_test "Test 4: Backup-restore via host PATH" "$SCRIPT_DIR/test-pg-host-path.sh"
|
||||
run_test "Test 5: Backup-restore via host bindir" "$SCRIPT_DIR/test-pg-host-bindir.sh"
|
||||
|
||||
elif [ "$MODE" = "docker" ]; then
|
||||
run_test "Test 6: pg_basebackup via docker exec" "$SCRIPT_DIR/test-pg-docker-exec.sh"
|
||||
run_test "Test 6: Backup-restore via docker exec" "$SCRIPT_DIR/test-pg-docker-exec.sh"
|
||||
|
||||
else
|
||||
echo "Unknown mode: $MODE (expected 'host' or 'docker')"
|
||||
@@ -42,6 +44,11 @@ fi
|
||||
echo ""
|
||||
echo "========================================"
|
||||
echo " Results: $PASSED passed, $FAILED failed"
|
||||
if [ "$FAILED" -gt 0 ]; then
|
||||
echo ""
|
||||
echo " Failed:"
|
||||
echo -e "$FAILED_NAMES"
|
||||
fi
|
||||
echo "========================================"
|
||||
|
||||
if [ "$FAILED" -gt 0 ]; then
|
||||
|
||||
@@ -1,23 +1,18 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
SCRIPT_DIR="$(dirname "$0")"
|
||||
source "$SCRIPT_DIR/backup-restore-helpers.sh"
|
||||
|
||||
MOCK_SERVER="${MOCK_SERVER_OVERRIDE:-http://e2e-mock-server:4050}"
|
||||
PG_CONTAINER="e2e-agent-postgres"
|
||||
RESTORE_PGDATA="/tmp/restore-pgdata"
|
||||
WAL_QUEUE="/wal-queue"
|
||||
PG_PORT=5432
|
||||
|
||||
# Cleanup from previous runs
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
|
||||
# Copy agent binary
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
# For restore verification we need a local PG bin dir
|
||||
PG_BIN_DIR=$(find_pg_bin_dir)
|
||||
echo "Using local PG bin dir for restore verification: $PG_BIN_DIR"
|
||||
|
||||
# Verify docker CLI works and PG container is accessible
|
||||
if ! docker exec "$PG_CONTAINER" pg_basebackup --version > /dev/null 2>&1; then
|
||||
@@ -25,37 +20,76 @@ if ! docker exec "$PG_CONTAINER" pg_basebackup --version > /dev/null 2>&1; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run start with --skip-update and pg-type=docker
|
||||
echo "Running agent start (pg_basebackup via docker exec)..."
|
||||
OUTPUT=$("$AGENT" start \
|
||||
--skip-update \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--pg-wal-dir /tmp/wal \
|
||||
--pg-type docker \
|
||||
--pg-docker-container-name "$PG_CONTAINER" 2>&1)
|
||||
echo "=== Phase 1: Setup agent ==="
|
||||
setup_agent
|
||||
|
||||
EXIT_CODE=$?
|
||||
echo "$OUTPUT"
|
||||
echo "=== Phase 2: Insert test data into containerized PostgreSQL ==="
|
||||
docker exec "$PG_CONTAINER" psql -U testuser -d testdb -c "
|
||||
CREATE TABLE IF NOT EXISTS e2e_test_data (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
value INT NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
DELETE FROM e2e_test_data;
|
||||
INSERT INTO e2e_test_data (name, value) VALUES
|
||||
('row1', 100),
|
||||
('row2', 200),
|
||||
('row3', 300);
|
||||
"
|
||||
echo "Test data inserted (3 rows)"
|
||||
|
||||
if [ "$EXIT_CODE" -ne 0 ]; then
|
||||
echo "FAIL: Agent exited with code $EXIT_CODE"
|
||||
exit 1
|
||||
fi
|
||||
echo "=== Phase 3: Start agent backup (docker exec mode) ==="
|
||||
curl -sf -X POST "$MOCK_SERVER/mock/reset" > /dev/null
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "pg_basebackup verified (docker)"; then
|
||||
echo "FAIL: Expected output to contain 'pg_basebackup verified (docker)'"
|
||||
exit 1
|
||||
fi
|
||||
cd /tmp
|
||||
cat > databasus.json <<AGENTCONF
|
||||
{
|
||||
"databasusHost": "$MOCK_SERVER",
|
||||
"dbId": "test-db-id",
|
||||
"token": "test-token",
|
||||
"pgHost": "$PG_CONTAINER",
|
||||
"pgPort": $PG_PORT,
|
||||
"pgUser": "testuser",
|
||||
"pgPassword": "testpassword",
|
||||
"pgType": "docker",
|
||||
"pgDockerContainerName": "$PG_CONTAINER",
|
||||
"pgWalDir": "$WAL_QUEUE",
|
||||
"deleteWalAfterUpload": true
|
||||
}
|
||||
AGENTCONF
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "PostgreSQL connection verified"; then
|
||||
echo "FAIL: Expected output to contain 'PostgreSQL connection verified'"
|
||||
exit 1
|
||||
fi
|
||||
"$AGENT" _run > /tmp/agent-output.log 2>&1 &
|
||||
AGENT_PID=$!
|
||||
echo "Agent started with PID $AGENT_PID"
|
||||
|
||||
echo "pg_basebackup found via docker exec and DB connection verified"
|
||||
echo "=== Phase 4: Generate WAL in background ==="
|
||||
generate_wal_docker_background "$PG_CONTAINER" &
|
||||
WAL_GEN_PID=$!
|
||||
|
||||
echo "=== Phase 5: Wait for backup to complete ==="
|
||||
wait_for_backup_complete "$MOCK_SERVER" 120
|
||||
|
||||
echo "=== Phase 6: Stop WAL generator and agent ==="
|
||||
kill $WAL_GEN_PID 2>/dev/null || true
|
||||
wait $WAL_GEN_PID 2>/dev/null || true
|
||||
stop_agent
|
||||
|
||||
echo "=== Phase 7: Restore to local directory ==="
|
||||
run_agent_restore "$MOCK_SERVER" "$RESTORE_PGDATA"
|
||||
|
||||
echo "=== Phase 8: Start local PostgreSQL on restored data ==="
|
||||
# Use a different port to avoid conflict with the containerized PG
|
||||
RESTORE_PORT=5433
|
||||
start_restored_pg "$RESTORE_PGDATA" "$RESTORE_PORT" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 9: Wait for recovery ==="
|
||||
wait_for_recovery_complete "$RESTORE_PORT" "$PG_BIN_DIR" 60
|
||||
|
||||
echo "=== Phase 10: Verify data ==="
|
||||
verify_restored_data "$RESTORE_PORT" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 11: Cleanup ==="
|
||||
stop_pg "$RESTORE_PGDATA" "$PG_BIN_DIR"
|
||||
|
||||
echo "pg_basebackup via docker exec: full backup-restore lifecycle passed"
|
||||
|
||||
@@ -1,67 +1,62 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
SCRIPT_DIR="$(dirname "$0")"
|
||||
source "$SCRIPT_DIR/backup-restore-helpers.sh"
|
||||
|
||||
MOCK_SERVER="${MOCK_SERVER_OVERRIDE:-http://e2e-mock-server:4050}"
|
||||
PGDATA="/tmp/pgdata"
|
||||
RESTORE_PGDATA="/tmp/restore-pgdata"
|
||||
WAL_QUEUE="/tmp/wal-queue"
|
||||
PG_PORT=5433
|
||||
CUSTOM_BIN_DIR="/opt/pg/bin"
|
||||
|
||||
# Cleanup from previous runs
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
PG_BIN_DIR=$(find_pg_bin_dir)
|
||||
echo "Using PG bin dir: $PG_BIN_DIR"
|
||||
|
||||
# Copy agent binary
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Move pg_basebackup out of PATH into custom directory
|
||||
# Copy pg_basebackup to a custom directory (simulates non-PATH installation)
|
||||
mkdir -p "$CUSTOM_BIN_DIR"
|
||||
cp "$(which pg_basebackup)" "$CUSTOM_BIN_DIR/pg_basebackup"
|
||||
cp "$PG_BIN_DIR/pg_basebackup" "$CUSTOM_BIN_DIR/pg_basebackup"
|
||||
|
||||
# Hide the system one by prepending an empty dir to PATH
|
||||
export PATH="/opt/empty-path:$PATH"
|
||||
mkdir -p /opt/empty-path
|
||||
echo "=== Phase 1: Setup agent ==="
|
||||
setup_agent
|
||||
|
||||
# Verify pg_basebackup is NOT directly callable from default location
|
||||
# (we copied it, but the original is still there in debian — so we test
|
||||
# that the agent uses the custom dir, not PATH, by checking the output)
|
||||
echo "=== Phase 2: Initialize PostgreSQL ==="
|
||||
init_pg_local "$PGDATA" "$PG_PORT" "$WAL_QUEUE" "$PG_BIN_DIR"
|
||||
|
||||
# Run start with --skip-update and custom bin dir
|
||||
echo "Running agent start (pg_basebackup via --pg-host-bin-dir)..."
|
||||
OUTPUT=$("$AGENT" start \
|
||||
--skip-update \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--pg-wal-dir /tmp/wal \
|
||||
--pg-type host \
|
||||
--pg-host-bin-dir "$CUSTOM_BIN_DIR" 2>&1)
|
||||
echo "=== Phase 3: Insert test data ==="
|
||||
insert_test_data "$PG_PORT" "$PG_BIN_DIR"
|
||||
|
||||
EXIT_CODE=$?
|
||||
echo "$OUTPUT"
|
||||
echo "=== Phase 4: Force checkpoint and start agent backup (using --pg-host-bin-dir) ==="
|
||||
force_checkpoint "$PG_PORT" "$PG_BIN_DIR"
|
||||
run_agent_backup "$MOCK_SERVER" "127.0.0.1" "$PG_PORT" "$WAL_QUEUE" "host" "$CUSTOM_BIN_DIR"
|
||||
|
||||
if [ "$EXIT_CODE" -ne 0 ]; then
|
||||
echo "FAIL: Agent exited with code $EXIT_CODE"
|
||||
exit 1
|
||||
fi
|
||||
echo "=== Phase 5: Generate WAL in background ==="
|
||||
generate_wal_background "$PG_PORT" "$PG_BIN_DIR" &
|
||||
WAL_GEN_PID=$!
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "pg_basebackup verified"; then
|
||||
echo "FAIL: Expected output to contain 'pg_basebackup verified'"
|
||||
exit 1
|
||||
fi
|
||||
echo "=== Phase 6: Wait for backup to complete ==="
|
||||
wait_for_backup_complete "$MOCK_SERVER" 120
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "PostgreSQL connection verified"; then
|
||||
echo "FAIL: Expected output to contain 'PostgreSQL connection verified'"
|
||||
exit 1
|
||||
fi
|
||||
echo "=== Phase 7: Stop WAL generator, agent, and PostgreSQL ==="
|
||||
kill $WAL_GEN_PID 2>/dev/null || true
|
||||
wait $WAL_GEN_PID 2>/dev/null || true
|
||||
stop_agent
|
||||
stop_pg "$PGDATA" "$PG_BIN_DIR"
|
||||
|
||||
echo "pg_basebackup found via custom bin dir and DB connection verified"
|
||||
echo "=== Phase 8: Restore ==="
|
||||
run_agent_restore "$MOCK_SERVER" "$RESTORE_PGDATA"
|
||||
|
||||
echo "=== Phase 9: Start PostgreSQL on restored data ==="
|
||||
start_restored_pg "$RESTORE_PGDATA" "$PG_PORT" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 10: Wait for recovery ==="
|
||||
wait_for_recovery_complete "$PG_PORT" "$PG_BIN_DIR" 60
|
||||
|
||||
echo "=== Phase 11: Verify data ==="
|
||||
verify_restored_data "$PG_PORT" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 12: Cleanup ==="
|
||||
stop_pg "$RESTORE_PGDATA" "$PG_BIN_DIR"
|
||||
|
||||
echo "pg_basebackup via custom bindir: full backup-restore lifecycle passed"
|
||||
|
||||
@@ -1,22 +1,17 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
SCRIPT_DIR="$(dirname "$0")"
|
||||
source "$SCRIPT_DIR/backup-restore-helpers.sh"
|
||||
|
||||
# Cleanup from previous runs
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
MOCK_SERVER="${MOCK_SERVER_OVERRIDE:-http://e2e-mock-server:4050}"
|
||||
PGDATA="/tmp/pgdata"
|
||||
RESTORE_PGDATA="/tmp/restore-pgdata"
|
||||
WAL_QUEUE="/tmp/wal-queue"
|
||||
PG_PORT=5433
|
||||
|
||||
# Copy agent binary
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
PG_BIN_DIR=$(find_pg_bin_dir)
|
||||
echo "Using PG bin dir: $PG_BIN_DIR"
|
||||
|
||||
# Verify pg_basebackup is in PATH
|
||||
if ! which pg_basebackup > /dev/null 2>&1; then
|
||||
@@ -24,36 +19,45 @@ if ! which pg_basebackup > /dev/null 2>&1; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run start with --skip-update and pg-type=host
|
||||
echo "Running agent start (pg_basebackup in PATH)..."
|
||||
OUTPUT=$("$AGENT" start \
|
||||
--skip-update \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--pg-wal-dir /tmp/wal \
|
||||
--pg-type host 2>&1)
|
||||
echo "=== Phase 1: Setup agent ==="
|
||||
setup_agent
|
||||
|
||||
EXIT_CODE=$?
|
||||
echo "$OUTPUT"
|
||||
echo "=== Phase 2: Initialize PostgreSQL ==="
|
||||
init_pg_local "$PGDATA" "$PG_PORT" "$WAL_QUEUE" "$PG_BIN_DIR"
|
||||
|
||||
if [ "$EXIT_CODE" -ne 0 ]; then
|
||||
echo "FAIL: Agent exited with code $EXIT_CODE"
|
||||
exit 1
|
||||
fi
|
||||
echo "=== Phase 3: Insert test data ==="
|
||||
insert_test_data "$PG_PORT" "$PG_BIN_DIR"
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "pg_basebackup verified"; then
|
||||
echo "FAIL: Expected output to contain 'pg_basebackup verified'"
|
||||
exit 1
|
||||
fi
|
||||
echo "=== Phase 4: Force checkpoint and start agent backup ==="
|
||||
force_checkpoint "$PG_PORT" "$PG_BIN_DIR"
|
||||
run_agent_backup "$MOCK_SERVER" "127.0.0.1" "$PG_PORT" "$WAL_QUEUE" "host"
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "PostgreSQL connection verified"; then
|
||||
echo "FAIL: Expected output to contain 'PostgreSQL connection verified'"
|
||||
exit 1
|
||||
fi
|
||||
echo "=== Phase 5: Generate WAL in background ==="
|
||||
generate_wal_background "$PG_PORT" "$PG_BIN_DIR" &
|
||||
WAL_GEN_PID=$!
|
||||
|
||||
echo "pg_basebackup found in PATH and DB connection verified"
|
||||
echo "=== Phase 6: Wait for backup to complete ==="
|
||||
wait_for_backup_complete "$MOCK_SERVER" 120
|
||||
|
||||
echo "=== Phase 7: Stop WAL generator, agent, and PostgreSQL ==="
|
||||
kill $WAL_GEN_PID 2>/dev/null || true
|
||||
wait $WAL_GEN_PID 2>/dev/null || true
|
||||
stop_agent
|
||||
stop_pg "$PGDATA" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 8: Restore ==="
|
||||
run_agent_restore "$MOCK_SERVER" "$RESTORE_PGDATA"
|
||||
|
||||
echo "=== Phase 9: Start PostgreSQL on restored data ==="
|
||||
start_restored_pg "$RESTORE_PGDATA" "$PG_PORT" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 10: Wait for recovery ==="
|
||||
wait_for_recovery_complete "$PG_PORT" "$PG_BIN_DIR" 60
|
||||
|
||||
echo "=== Phase 11: Verify data ==="
|
||||
verify_restored_data "$PG_PORT" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 12: Cleanup ==="
|
||||
stop_pg "$RESTORE_PGDATA" "$PG_BIN_DIR"
|
||||
|
||||
echo "pg_basebackup in PATH: full backup-restore lifecycle passed"
|
||||
|
||||
@@ -59,3 +59,6 @@ if [ "$VERSION" != "v1.0.0" ]; then
|
||||
fi
|
||||
|
||||
echo "Upgrade correctly skipped, version still $VERSION"
|
||||
|
||||
# Cleanup daemon
|
||||
"$AGENT" stop || true
|
||||
|
||||
@@ -64,3 +64,6 @@ if [ "$VERSION" != "v2.0.0" ]; then
|
||||
fi
|
||||
|
||||
echo "Binary upgraded successfully to $VERSION"
|
||||
|
||||
# Cleanup daemon
|
||||
"$AGENT" stop || true
|
||||
|
||||
@@ -110,8 +110,7 @@ func (c *Config) applyDefaults() {
|
||||
}
|
||||
|
||||
if c.IsDeleteWalAfterUpload == nil {
|
||||
v := true
|
||||
c.IsDeleteWalAfterUpload = &v
|
||||
c.IsDeleteWalAfterUpload = new(true)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
@@ -14,25 +15,30 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
chainValidPath = "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup"
|
||||
nextBackupTimePath = "/api/v1/backups/postgres/wal/next-full-backup-time"
|
||||
walUploadPath = "/api/v1/backups/postgres/wal/upload/wal"
|
||||
fullStartPath = "/api/v1/backups/postgres/wal/upload/full-start"
|
||||
fullCompletePath = "/api/v1/backups/postgres/wal/upload/full-complete"
|
||||
reportErrorPath = "/api/v1/backups/postgres/wal/error"
|
||||
versionPath = "/api/v1/system/version"
|
||||
agentBinaryPath = "/api/v1/system/agent"
|
||||
chainValidPath = "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup"
|
||||
nextBackupTimePath = "/api/v1/backups/postgres/wal/next-full-backup-time"
|
||||
walUploadPath = "/api/v1/backups/postgres/wal/upload/wal"
|
||||
fullStartPath = "/api/v1/backups/postgres/wal/upload/full-start"
|
||||
fullCompletePath = "/api/v1/backups/postgres/wal/upload/full-complete"
|
||||
reportErrorPath = "/api/v1/backups/postgres/wal/error"
|
||||
restorePlanPath = "/api/v1/backups/postgres/wal/restore/plan"
|
||||
restoreDownloadPath = "/api/v1/backups/postgres/wal/restore/download"
|
||||
versionPath = "/api/v1/system/version"
|
||||
agentBinaryPath = "/api/v1/system/agent"
|
||||
|
||||
apiCallTimeout = 30 * time.Second
|
||||
maxRetryAttempts = 3
|
||||
retryBaseDelay = 1 * time.Second
|
||||
)
|
||||
|
||||
// For stream uploads (basebackup and WAL segments) the standard resty client is not used,
|
||||
// because it buffers the entire body in memory before sending.
|
||||
type Client struct {
|
||||
json *resty.Client
|
||||
stream *resty.Client
|
||||
host string
|
||||
log *slog.Logger
|
||||
json *resty.Client
|
||||
streamHTTP *http.Client
|
||||
host string
|
||||
token string
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
func NewClient(host, token string, log *slog.Logger) *Client {
|
||||
@@ -54,14 +60,12 @@ func NewClient(host, token string, log *slog.Logger) *Client {
|
||||
}).
|
||||
OnBeforeRequest(setAuth)
|
||||
|
||||
streamClient := resty.New().
|
||||
OnBeforeRequest(setAuth)
|
||||
|
||||
return &Client{
|
||||
json: jsonClient,
|
||||
stream: streamClient,
|
||||
host: host,
|
||||
log: log,
|
||||
json: jsonClient,
|
||||
streamHTTP: &http.Client{},
|
||||
host: host,
|
||||
token: token,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,25 +121,28 @@ func (c *Client) UploadBasebackup(
|
||||
ctx context.Context,
|
||||
body io.Reader,
|
||||
) (*UploadBasebackupResponse, error) {
|
||||
resp, err := c.stream.R().
|
||||
SetContext(ctx).
|
||||
SetBody(body).
|
||||
SetHeader("Content-Type", "application/octet-stream").
|
||||
SetDoNotParseResponse(true).
|
||||
Post(c.buildURL(fullStartPath))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL(fullStartPath), body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create upload request: %w", err)
|
||||
}
|
||||
|
||||
c.setStreamHeaders(req)
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
|
||||
resp, err := c.streamHTTP.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upload request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.RawBody().Close() }()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode() != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.RawBody())
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
return nil, fmt.Errorf("upload failed with status %d: %s", resp.StatusCode(), string(respBody))
|
||||
return nil, fmt.Errorf("upload failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result UploadBasebackupResponse
|
||||
if err := json.NewDecoder(resp.RawBody()).Decode(&result); err != nil {
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("decode upload response: %w", err)
|
||||
}
|
||||
|
||||
@@ -195,26 +202,29 @@ func (c *Client) UploadWalSegment(
|
||||
segmentName string,
|
||||
body io.Reader,
|
||||
) (*UploadWalSegmentResult, error) {
|
||||
resp, err := c.stream.R().
|
||||
SetContext(ctx).
|
||||
SetBody(body).
|
||||
SetHeader("Content-Type", "application/octet-stream").
|
||||
SetHeader("X-Wal-Segment-Name", segmentName).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(c.buildURL(walUploadPath))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL(walUploadPath), body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create WAL upload request: %w", err)
|
||||
}
|
||||
|
||||
c.setStreamHeaders(req)
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
req.Header.Set("X-Wal-Segment-Name", segmentName)
|
||||
|
||||
resp, err := c.streamHTTP.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upload request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.RawBody().Close() }()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
switch resp.StatusCode() {
|
||||
switch resp.StatusCode {
|
||||
case http.StatusNoContent:
|
||||
return &UploadWalSegmentResult{IsGapDetected: false}, nil
|
||||
|
||||
case http.StatusConflict:
|
||||
var errResp uploadErrorResponse
|
||||
|
||||
if err := json.NewDecoder(resp.RawBody()).Decode(&errResp); err != nil {
|
||||
if err := json.NewDecoder(resp.Body).Decode(&errResp); err != nil {
|
||||
return &UploadWalSegmentResult{IsGapDetected: true}, nil
|
||||
}
|
||||
|
||||
@@ -225,12 +235,79 @@ func (c *Client) UploadWalSegment(
|
||||
}, nil
|
||||
|
||||
default:
|
||||
respBody, _ := io.ReadAll(resp.RawBody())
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
return nil, fmt.Errorf("upload failed with status %d: %s", resp.StatusCode(), string(respBody))
|
||||
return nil, fmt.Errorf("upload failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) GetRestorePlan(
|
||||
ctx context.Context,
|
||||
backupID string,
|
||||
) (*GetRestorePlanResponse, *GetRestorePlanErrorResponse, error) {
|
||||
request := c.json.R().SetContext(ctx)
|
||||
|
||||
if backupID != "" {
|
||||
request.SetQueryParam("backupId", backupID)
|
||||
}
|
||||
|
||||
httpResp, err := request.Get(c.buildURL(restorePlanPath))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get restore plan: %w", err)
|
||||
}
|
||||
|
||||
switch httpResp.StatusCode() {
|
||||
case http.StatusOK:
|
||||
var response GetRestorePlanResponse
|
||||
if err := json.Unmarshal(httpResp.Body(), &response); err != nil {
|
||||
return nil, nil, fmt.Errorf("decode restore plan response: %w", err)
|
||||
}
|
||||
|
||||
return &response, nil, nil
|
||||
|
||||
case http.StatusBadRequest:
|
||||
var errorResponse GetRestorePlanErrorResponse
|
||||
if err := json.Unmarshal(httpResp.Body(), &errorResponse); err != nil {
|
||||
return nil, nil, fmt.Errorf("decode restore plan error: %w", err)
|
||||
}
|
||||
|
||||
return nil, &errorResponse, nil
|
||||
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("get restore plan: server returned status %d: %s",
|
||||
httpResp.StatusCode(), httpResp.String())
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) DownloadBackupFile(
|
||||
ctx context.Context,
|
||||
backupID string,
|
||||
) (io.ReadCloser, error) {
|
||||
requestURL := c.buildURL(restoreDownloadPath) + "?" + url.Values{"backupId": {backupID}}.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create download request: %w", err)
|
||||
}
|
||||
|
||||
c.setStreamHeaders(req)
|
||||
|
||||
resp, err := c.streamHTTP.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("download backup file: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
|
||||
return nil, fmt.Errorf("download backup file: server returned status %d: %s",
|
||||
resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
func (c *Client) FetchServerVersion(ctx context.Context) (string, error) {
|
||||
var ver versionResponse
|
||||
|
||||
@@ -250,27 +327,32 @@ func (c *Client) FetchServerVersion(ctx context.Context) (string, error) {
|
||||
}
|
||||
|
||||
func (c *Client) DownloadAgentBinary(ctx context.Context, arch, destPath string) error {
|
||||
resp, err := c.stream.R().
|
||||
SetContext(ctx).
|
||||
SetQueryParam("arch", arch).
|
||||
SetDoNotParseResponse(true).
|
||||
Get(c.buildURL(agentBinaryPath))
|
||||
requestURL := c.buildURL(agentBinaryPath) + "?" + url.Values{"arch": {arch}}.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create agent download request: %w", err)
|
||||
}
|
||||
|
||||
c.setStreamHeaders(req)
|
||||
|
||||
resp, err := c.streamHTTP.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.RawBody().Close() }()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode() != http.StatusOK {
|
||||
return fmt.Errorf("server returned %d for agent download", resp.StatusCode())
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("server returned %d for agent download", resp.StatusCode)
|
||||
}
|
||||
|
||||
f, err := os.Create(destPath)
|
||||
file, err := os.Create(destPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
_, err = io.Copy(f, resp.RawBody())
|
||||
_, err = io.Copy(file, resp.Body)
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -286,3 +368,9 @@ func (c *Client) checkResponse(resp *resty.Response, method string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) setStreamHeaders(req *http.Request) {
|
||||
if c.token != "" {
|
||||
req.Header.Set("Authorization", c.token)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,3 +42,31 @@ type uploadErrorResponse struct {
|
||||
ExpectedSegmentName string `json:"expectedSegmentName"`
|
||||
ReceivedSegmentName string `json:"receivedSegmentName"`
|
||||
}
|
||||
|
||||
type RestorePlanFullBackup struct {
|
||||
BackupID string `json:"id"`
|
||||
FullBackupWalStartSegment string `json:"fullBackupWalStartSegment"`
|
||||
FullBackupWalStopSegment string `json:"fullBackupWalStopSegment"`
|
||||
PgVersion string `json:"pgVersion"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
SizeBytes int64 `json:"sizeBytes"`
|
||||
}
|
||||
|
||||
type RestorePlanWalSegment struct {
|
||||
BackupID string `json:"backupId"`
|
||||
SegmentName string `json:"segmentName"`
|
||||
SizeBytes int64 `json:"sizeBytes"`
|
||||
}
|
||||
|
||||
type GetRestorePlanResponse struct {
|
||||
FullBackup RestorePlanFullBackup `json:"fullBackup"`
|
||||
WalSegments []RestorePlanWalSegment `json:"walSegments"`
|
||||
TotalSizeBytes int64 `json:"totalSizeBytes"`
|
||||
LatestAvailableSegment string `json:"latestAvailableSegment"`
|
||||
}
|
||||
|
||||
type GetRestorePlanErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Message string `json:"message"`
|
||||
LastContiguousSegment string `json:"lastContiguousSegment,omitempty"`
|
||||
}
|
||||
|
||||
60
agent/internal/features/api/idle_timeout_reader.go
Normal file
60
agent/internal/features/api/idle_timeout_reader.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IdleTimeoutReader wraps an io.Reader and cancels the associated context
|
||||
// if no bytes are successfully read within the specified timeout duration.
|
||||
// This detects stalled uploads where the network or source stops transmitting data.
|
||||
//
|
||||
// When the idle timeout fires, the reader is also closed (if it implements io.Closer)
|
||||
// to unblock any goroutine blocked on the underlying Read.
|
||||
type IdleTimeoutReader struct {
|
||||
reader io.Reader
|
||||
timeout time.Duration
|
||||
cancel context.CancelCauseFunc
|
||||
timer *time.Timer
|
||||
}
|
||||
|
||||
// NewIdleTimeoutReader creates a reader that cancels the context via cancel
|
||||
// if Read does not return any bytes for the given timeout duration.
|
||||
func NewIdleTimeoutReader(reader io.Reader, timeout time.Duration, cancel context.CancelCauseFunc) *IdleTimeoutReader {
|
||||
r := &IdleTimeoutReader{
|
||||
reader: reader,
|
||||
timeout: timeout,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
r.timer = time.AfterFunc(timeout, func() {
|
||||
cancel(fmt.Errorf("upload idle timeout: no bytes transmitted for %v", timeout))
|
||||
|
||||
if closer, ok := reader.(io.Closer); ok {
|
||||
_ = closer.Close()
|
||||
}
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *IdleTimeoutReader) Read(p []byte) (int, error) {
|
||||
n, err := r.reader.Read(p)
|
||||
|
||||
if n > 0 {
|
||||
r.timer.Reset(r.timeout)
|
||||
}
|
||||
|
||||
if err != nil && err != io.EOF {
|
||||
r.Stop()
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Stop cancels the idle timer. Must be called when the reader is no longer needed.
|
||||
func (r *IdleTimeoutReader) Stop() {
|
||||
r.timer.Stop()
|
||||
}
|
||||
112
agent/internal/features/api/idle_timeout_reader_test.go
Normal file
112
agent/internal/features/api/idle_timeout_reader_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_ReadThroughIdleTimeoutReader_WhenBytesFlowContinuously_DoesNotCancelContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
defer cancel(nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
idleReader := NewIdleTimeoutReader(pr, 200*time.Millisecond, cancel)
|
||||
defer idleReader.Stop()
|
||||
|
||||
go func() {
|
||||
for range 5 {
|
||||
_, _ = pw.Write([]byte("data"))
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
_ = pw.Close()
|
||||
}()
|
||||
|
||||
data, err := io.ReadAll(idleReader)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "datadatadatadatadata", string(data))
|
||||
assert.NoError(t, ctx.Err(), "context should not be cancelled when bytes flow continuously")
|
||||
}
|
||||
|
||||
func Test_ReadThroughIdleTimeoutReader_WhenNoBytesTransmitted_CancelsContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
defer cancel(nil)
|
||||
|
||||
pr, _ := io.Pipe()
|
||||
|
||||
idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel)
|
||||
defer idleReader.Stop()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
assert.Error(t, ctx.Err(), "context should be cancelled when no bytes are transmitted")
|
||||
assert.Contains(t, context.Cause(ctx).Error(), "upload idle timeout")
|
||||
}
|
||||
|
||||
func Test_ReadThroughIdleTimeoutReader_WhenBytesStopMidStream_CancelsContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
defer cancel(nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel)
|
||||
defer idleReader.Stop()
|
||||
|
||||
go func() {
|
||||
_, _ = pw.Write([]byte("initial"))
|
||||
// Stop writing — simulate stalled source
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
n, _ := idleReader.Read(buf)
|
||||
assert.Equal(t, "initial", string(buf[:n]))
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
assert.Error(t, ctx.Err(), "context should be cancelled when bytes stop mid-stream")
|
||||
assert.Contains(t, context.Cause(ctx).Error(), "upload idle timeout")
|
||||
}
|
||||
|
||||
func Test_StopIdleTimeoutReader_WhenCalledBeforeTimeout_DoesNotCancelContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
defer cancel(nil)
|
||||
|
||||
pr, _ := io.Pipe()
|
||||
|
||||
idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel)
|
||||
idleReader.Stop()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
assert.NoError(t, ctx.Err(), "context should not be cancelled when reader is stopped before timeout")
|
||||
}
|
||||
|
||||
func Test_ReadThroughIdleTimeoutReader_WhenReaderReturnsError_PropagatesError(t *testing.T) {
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
defer cancel(nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
idleReader := NewIdleTimeoutReader(pr, 5*time.Second, cancel)
|
||||
defer idleReader.Stop()
|
||||
|
||||
expectedErr := fmt.Errorf("test read error")
|
||||
_ = pw.CloseWithError(expectedErr)
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
_, err := idleReader.Read(buf)
|
||||
|
||||
assert.ErrorIs(t, err, expectedErr)
|
||||
|
||||
// Timer should be stopped after error — context should not be cancelled
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
assert.NoError(t, ctx.Err(), "context should not be cancelled after reader error stops the timer")
|
||||
}
|
||||
@@ -21,9 +21,11 @@ import (
|
||||
const (
|
||||
checkInterval = 30 * time.Second
|
||||
retryDelay = 1 * time.Minute
|
||||
uploadTimeout = 30 * time.Minute
|
||||
uploadTimeout = 23 * time.Hour
|
||||
)
|
||||
|
||||
var uploadIdleTimeout = 5 * time.Minute
|
||||
|
||||
var retryDelayOverride *time.Duration
|
||||
|
||||
type CmdBuilder func(ctx context.Context) *exec.Cmd
|
||||
@@ -38,8 +40,9 @@ type CmdBuilder func(ctx context.Context) *exec.Cmd
|
||||
// On failure the error is reported to the server and the backup retries after 1 minute, indefinitely.
|
||||
// WAL segment uploads (handled by wal.Streamer) continue independently and are not paused.
|
||||
//
|
||||
// pg_basebackup runs as "pg_basebackup -Ft -D - -X none --verbose". Stdout (tar) is zstd-compressed
|
||||
// and uploaded to the server. Stderr is parsed for WAL start/stop segment names (LSN → segment arithmetic).
|
||||
// pg_basebackup runs as "pg_basebackup -Ft -D - -X fetch --verbose --checkpoint=fast".
|
||||
// Stdout (tar) is zstd-compressed and uploaded to the server.
|
||||
// Stderr is parsed for WAL start/stop segment names (LSN → segment arithmetic).
|
||||
type FullBackuper struct {
|
||||
cfg *config.Config
|
||||
apiClient *api.Client
|
||||
@@ -175,16 +178,37 @@ func (backuper *FullBackuper) executeAndUploadBasebackup(ctx context.Context) er
|
||||
|
||||
// Phase 1: Stream compressed data via io.Pipe directly to the API.
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
defer func() { _ = pipeReader.Close() }()
|
||||
|
||||
go backuper.compressAndStream(pipeWriter, stdoutPipe)
|
||||
|
||||
uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout)
|
||||
defer cancel()
|
||||
uploadCtx, timeoutCancel := context.WithTimeout(ctx, uploadTimeout)
|
||||
defer timeoutCancel()
|
||||
|
||||
uploadResp, uploadErr := backuper.apiClient.UploadBasebackup(uploadCtx, pipeReader)
|
||||
idleCtx, idleCancel := context.WithCancelCause(uploadCtx)
|
||||
defer idleCancel(nil)
|
||||
|
||||
idleReader := api.NewIdleTimeoutReader(pipeReader, uploadIdleTimeout, idleCancel)
|
||||
defer idleReader.Stop()
|
||||
|
||||
uploadResp, uploadErr := backuper.apiClient.UploadBasebackup(idleCtx, idleReader)
|
||||
|
||||
if uploadErr != nil && cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
|
||||
cmdErr := cmd.Wait()
|
||||
|
||||
if uploadErr != nil {
|
||||
if cause := context.Cause(idleCtx); cause != nil {
|
||||
uploadErr = cause
|
||||
}
|
||||
|
||||
stderrStr := stderrBuf.String()
|
||||
if stderrStr != "" {
|
||||
return fmt.Errorf("upload basebackup: %w (pg_basebackup stderr: %s)", uploadErr, stderrStr)
|
||||
}
|
||||
|
||||
return fmt.Errorf("upload basebackup: %w", uploadErr)
|
||||
}
|
||||
|
||||
@@ -192,7 +216,7 @@ func (backuper *FullBackuper) executeAndUploadBasebackup(ctx context.Context) er
|
||||
errMsg := fmt.Sprintf("pg_basebackup exited with error: %v (stderr: %s)", cmdErr, stderrBuf.String())
|
||||
_ = backuper.apiClient.FinalizeBasebackupWithError(ctx, uploadResp.BackupID, errMsg)
|
||||
|
||||
return fmt.Errorf("pg_basebackup: %w", cmdErr)
|
||||
return fmt.Errorf("%s", errMsg)
|
||||
}
|
||||
|
||||
// Phase 2: Parse stderr for WAL segments and finalize the backup.
|
||||
@@ -266,7 +290,7 @@ func (backuper *FullBackuper) buildHostCmd(ctx context.Context) *exec.Cmd {
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, binary,
|
||||
"-Ft", "-D", "-", "-X", "none", "--verbose",
|
||||
"-Ft", "-D", "-", "-X", "fetch", "--verbose", "--checkpoint=fast",
|
||||
"-h", backuper.cfg.PgHost,
|
||||
"-p", fmt.Sprintf("%d", backuper.cfg.PgPort),
|
||||
"-U", backuper.cfg.PgUser,
|
||||
@@ -282,9 +306,9 @@ func (backuper *FullBackuper) buildDockerCmd(ctx context.Context) *exec.Cmd {
|
||||
"-e", "PGPASSWORD="+backuper.cfg.PgPassword,
|
||||
"-i", backuper.cfg.PgDockerContainerName,
|
||||
"pg_basebackup",
|
||||
"-Ft", "-D", "-", "-X", "none", "--verbose",
|
||||
"-h", backuper.cfg.PgHost,
|
||||
"-p", fmt.Sprintf("%d", backuper.cfg.PgPort),
|
||||
"-Ft", "-D", "-", "-X", "fetch", "--verbose", "--checkpoint=fast",
|
||||
"-h", "localhost",
|
||||
"-p", "5432",
|
||||
"-U", backuper.cfg.PgUser,
|
||||
)
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ func Test_RunFullBackup_WhenChainBroken_BasebackupTriggered(t *testing.T) {
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "test-backup-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
@@ -124,7 +124,7 @@ func Test_RunFullBackup_WhenScheduledBackupDue_BasebackupTriggered(t *testing.T)
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "scheduled-backup-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
@@ -169,7 +169,7 @@ func Test_RunFullBackup_WhenNoFullBackupExists_ImmediateBasebackupTriggered(t *t
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "first-backup-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
@@ -233,7 +233,7 @@ func Test_RunFullBackup_WhenUploadFails_RetriesAfterDelay(t *testing.T) {
|
||||
setRetryDelay(100 * time.Millisecond)
|
||||
defer setRetryDelay(origRetryDelay)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
@@ -282,7 +282,7 @@ func Test_RunFullBackup_WhenAlreadyRunning_SkipsExecution(t *testing.T) {
|
||||
|
||||
fb.isRunning.Store(true)
|
||||
|
||||
fb.checkAndRunIfNeeded(context.Background())
|
||||
fb.checkAndRunIfNeeded(t.Context())
|
||||
|
||||
mu.Lock()
|
||||
count := uploadCount
|
||||
@@ -318,7 +318,7 @@ func Test_RunFullBackup_WhenContextCancelled_StopsCleanly(t *testing.T) {
|
||||
setRetryDelay(5 * time.Second)
|
||||
defer setRetryDelay(origRetryDelay)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
@@ -360,7 +360,7 @@ func Test_RunFullBackup_WhenChainValidAndNotScheduled_NoBasebackupTriggered(t *t
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
@@ -411,7 +411,7 @@ func Test_RunFullBackup_WhenStderrParsingFails_FinalizesWithErrorAndRetries(t *t
|
||||
setRetryDelay(100 * time.Millisecond)
|
||||
defer setRetryDelay(origRetryDelay)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
@@ -458,7 +458,7 @@ func Test_RunFullBackup_WhenNextBackupTimeNull_BasebackupTriggered(t *testing.T)
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "first-run-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
@@ -498,7 +498,7 @@ func Test_RunFullBackup_WhenChainValidityReturns401_NoBasebackupTriggered(t *tes
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
@@ -538,7 +538,7 @@ func Test_RunFullBackup_WhenUploadSucceeds_BodyIsZstdCompressed(t *testing.T) {
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, originalContent, validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
@@ -562,6 +562,68 @@ func Test_RunFullBackup_WhenUploadSucceeds_BodyIsZstdCompressed(t *testing.T) {
|
||||
assert.Equal(t, originalContent, string(decompressed))
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenUploadStalls_FailsWithIdleTimeout(t *testing.T) {
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testFullStartPath:
|
||||
// Server reads body normally — it will block until connection is closed
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = stallingCmdBuilder(t)
|
||||
|
||||
origIdleTimeout := uploadIdleTimeout
|
||||
uploadIdleTimeout = 200 * time.Millisecond
|
||||
defer func() { uploadIdleTimeout = origIdleTimeout }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := fb.executeAndUploadBasebackup(ctx)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "idle timeout", "error should mention idle timeout")
|
||||
}
|
||||
|
||||
func stallingCmdBuilder(t *testing.T) CmdBuilder {
|
||||
t.Helper()
|
||||
|
||||
return func(ctx context.Context) *exec.Cmd {
|
||||
cmd := exec.CommandContext(ctx, os.Args[0],
|
||||
"-test.run=TestHelperProcessStalling",
|
||||
"--",
|
||||
)
|
||||
|
||||
cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS_STALLING=1")
|
||||
|
||||
return cmd
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelperProcessStalling(t *testing.T) {
|
||||
if os.Getenv("GO_TEST_HELPER_PROCESS_STALLING") != "1" {
|
||||
return
|
||||
}
|
||||
|
||||
// Write enough data to flush through the zstd encoder's internal buffer (~128KB blocks).
|
||||
// Without enough data, zstd buffers everything and the pipe never receives bytes.
|
||||
data := make([]byte, 256*1024)
|
||||
for i := range data {
|
||||
data[i] = byte(i)
|
||||
}
|
||||
_, _ = os.Stdout.Write(data)
|
||||
|
||||
// Stall with stdout open — the compress goroutine blocks on its next read.
|
||||
// The parent process will kill us when the context is cancelled.
|
||||
time.Sleep(time.Hour)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
@@ -632,9 +694,11 @@ func TestHelperProcess(t *testing.T) {
|
||||
func validStderr() string {
|
||||
return `pg_basebackup: initiating base backup, waiting for checkpoint to complete
|
||||
pg_basebackup: checkpoint completed
|
||||
pg_basebackup: write-ahead log start point: 0/2000028, on timeline 1
|
||||
pg_basebackup: checkpoint redo point at 0/2000028
|
||||
pg_basebackup: write-ahead log start point: 0/2000028 on timeline 1
|
||||
pg_basebackup: starting background WAL receiver
|
||||
pg_basebackup: write-ahead log end point: 0/2000100
|
||||
pg_basebackup: waiting for background process to finish streaming ...
|
||||
pg_basebackup: syncing data to disk ...
|
||||
pg_basebackup: base backup completed`
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
const defaultWalSegmentSize uint32 = 16 * 1024 * 1024 // 16 MB
|
||||
|
||||
var (
|
||||
startLSNRegex = regexp.MustCompile(`checkpoint redo point at ([0-9A-Fa-f]+/[0-9A-Fa-f]+)`)
|
||||
startLSNRegex = regexp.MustCompile(`write-ahead log start point: ([0-9A-Fa-f]+/[0-9A-Fa-f]+)`)
|
||||
stopLSNRegex = regexp.MustCompile(`write-ahead log end point: ([0-9A-Fa-f]+/[0-9A-Fa-f]+)`)
|
||||
)
|
||||
|
||||
|
||||
@@ -7,12 +7,11 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_ParseBasebackupStderr_WithPG17Output_ExtractsCorrectSegments(t *testing.T) {
|
||||
func Test_ParseBasebackupStderr_WithPG17FetchOutput_ExtractsCorrectSegments(t *testing.T) {
|
||||
stderr := `pg_basebackup: initiating base backup, waiting for checkpoint to complete
|
||||
pg_basebackup: checkpoint completed
|
||||
pg_basebackup: write-ahead log start point: 0/2000028, on timeline 1
|
||||
pg_basebackup: write-ahead log start point: 0/2000028 on timeline 1
|
||||
pg_basebackup: starting background WAL receiver
|
||||
pg_basebackup: checkpoint redo point at 0/2000028
|
||||
pg_basebackup: write-ahead log end point: 0/2000100
|
||||
pg_basebackup: waiting for background process to finish streaming ...
|
||||
pg_basebackup: syncing data to disk ...
|
||||
@@ -26,13 +25,9 @@ pg_basebackup: base backup completed`
|
||||
assert.Equal(t, "000000010000000000000002", stopSeg)
|
||||
}
|
||||
|
||||
func Test_ParseBasebackupStderr_WithPG15Output_ExtractsCorrectSegments(t *testing.T) {
|
||||
stderr := `pg_basebackup: initiating base backup, waiting for checkpoint to complete
|
||||
pg_basebackup: checkpoint completed
|
||||
pg_basebackup: write-ahead log start point: 1/AB000028, on timeline 1
|
||||
pg_basebackup: checkpoint redo point at 1/AB000028
|
||||
pg_basebackup: write-ahead log end point: 1/AC000000
|
||||
pg_basebackup: base backup completed`
|
||||
func Test_ParseBasebackupStderr_WithHighLSNValues_ExtractsCorrectSegments(t *testing.T) {
|
||||
stderr := `pg_basebackup: write-ahead log start point: 1/AB000028 on timeline 1
|
||||
pg_basebackup: write-ahead log end point: 1/AC000000`
|
||||
|
||||
startSeg, stopSeg, err := ParseBasebackupStderr(stderr)
|
||||
|
||||
@@ -42,7 +37,7 @@ pg_basebackup: base backup completed`
|
||||
}
|
||||
|
||||
func Test_ParseBasebackupStderr_WithHighLogID_ExtractsCorrectSegments(t *testing.T) {
|
||||
stderr := `pg_basebackup: checkpoint redo point at A/FF000028
|
||||
stderr := `pg_basebackup: write-ahead log start point: A/FF000028 on timeline 1
|
||||
pg_basebackup: write-ahead log end point: B/1000000`
|
||||
|
||||
startSeg, stopSeg, err := ParseBasebackupStderr(stderr)
|
||||
@@ -63,7 +58,7 @@ pg_basebackup: base backup completed`
|
||||
}
|
||||
|
||||
func Test_ParseBasebackupStderr_WhenStopLSNMissing_ReturnsError(t *testing.T) {
|
||||
stderr := `pg_basebackup: checkpoint redo point at 0/2000028
|
||||
stderr := `pg_basebackup: write-ahead log start point: 0/2000028 on timeline 1
|
||||
pg_basebackup: base backup completed`
|
||||
|
||||
_, _, err := ParseBasebackupStderr(stderr)
|
||||
|
||||
444
agent/internal/features/restore/restorer.go
Normal file
444
agent/internal/features/restore/restorer.go
Normal file
@@ -0,0 +1,444 @@
|
||||
package restore
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"databasus-agent/internal/features/api"
|
||||
)
|
||||
|
||||
const (
|
||||
walRestoreDir = "databasus-wal-restore"
|
||||
maxRetryAttempts = 3
|
||||
retryBaseDelay = 1 * time.Second
|
||||
recoverySignalFile = "recovery.signal"
|
||||
autoConfFile = "postgresql.auto.conf"
|
||||
dockerContainerPgDataDir = "/var/lib/postgresql/data"
|
||||
)
|
||||
|
||||
var retryDelayOverride *time.Duration
|
||||
|
||||
type Restorer struct {
|
||||
apiClient *api.Client
|
||||
log *slog.Logger
|
||||
targetPgDataDir string
|
||||
backupID string
|
||||
targetTime string
|
||||
pgType string
|
||||
}
|
||||
|
||||
func NewRestorer(
|
||||
apiClient *api.Client,
|
||||
log *slog.Logger,
|
||||
targetPgDataDir string,
|
||||
backupID string,
|
||||
targetTime string,
|
||||
pgType string,
|
||||
) *Restorer {
|
||||
return &Restorer{
|
||||
apiClient,
|
||||
log,
|
||||
targetPgDataDir,
|
||||
backupID,
|
||||
targetTime,
|
||||
pgType,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Restorer) Run(ctx context.Context) error {
|
||||
var parsedTargetTime *time.Time
|
||||
|
||||
if r.targetTime != "" {
|
||||
parsed, err := time.Parse(time.RFC3339, r.targetTime)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid --target-time format (expected RFC3339, e.g. 2026-02-28T14:30:00Z): %w", err)
|
||||
}
|
||||
|
||||
parsedTargetTime = &parsed
|
||||
}
|
||||
|
||||
if err := r.validateTargetPgDataDir(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
plan, err := r.getRestorePlanFromServer(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.logRestorePlan(plan, parsedTargetTime)
|
||||
|
||||
r.log.Info("Downloading and extracting basebackup...")
|
||||
if err := r.downloadAndExtractBasebackup(ctx, plan.FullBackup.BackupID); err != nil {
|
||||
return fmt.Errorf("basebackup download failed: %w", err)
|
||||
}
|
||||
r.log.Info("Basebackup extracted successfully")
|
||||
|
||||
if err := r.downloadAllWalSegments(ctx, plan.WalSegments); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.configurePostgresRecovery(parsedTargetTime); err != nil {
|
||||
return fmt.Errorf("failed to configure recovery: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(r.targetPgDataDir, 0o700); err != nil {
|
||||
return fmt.Errorf("set PGDATA permissions: %w", err)
|
||||
}
|
||||
|
||||
r.printCompletionMessage()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Restorer) validateTargetPgDataDir() error {
|
||||
info, err := os.Stat(r.targetPgDataDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("target pgdata directory does not exist: %s", r.targetPgDataDir)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot access target pgdata directory: %w", err)
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
return fmt.Errorf("target pgdata path is not a directory: %s", r.targetPgDataDir)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(r.targetPgDataDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read target pgdata directory: %w", err)
|
||||
}
|
||||
|
||||
if len(entries) > 0 {
|
||||
return fmt.Errorf("target pgdata directory is not empty: %s", r.targetPgDataDir)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Restorer) getRestorePlanFromServer(ctx context.Context) (*api.GetRestorePlanResponse, error) {
|
||||
plan, planErr, err := r.apiClient.GetRestorePlan(ctx, r.backupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch restore plan: %w", err)
|
||||
}
|
||||
|
||||
if planErr != nil {
|
||||
if planErr.LastContiguousSegment != "" {
|
||||
return nil, fmt.Errorf("restore plan error: %s (last contiguous segment: %s)",
|
||||
planErr.Message, planErr.LastContiguousSegment)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("restore plan error: %s", planErr.Message)
|
||||
}
|
||||
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
func (r *Restorer) logRestorePlan(plan *api.GetRestorePlanResponse, parsedTargetTime *time.Time) {
|
||||
recoveryTarget := "full recovery (all available WAL)"
|
||||
if parsedTargetTime != nil {
|
||||
recoveryTarget = parsedTargetTime.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
r.log.Info("Restore plan",
|
||||
"fullBackupID", plan.FullBackup.BackupID,
|
||||
"fullBackupCreatedAt", plan.FullBackup.CreatedAt.Format(time.RFC3339),
|
||||
"pgVersion", plan.FullBackup.PgVersion,
|
||||
"walSegmentCount", len(plan.WalSegments),
|
||||
"totalDownloadSize", formatSizeBytes(plan.TotalSizeBytes),
|
||||
"latestAvailableSegment", plan.LatestAvailableSegment,
|
||||
"recoveryTarget", recoveryTarget,
|
||||
)
|
||||
}
|
||||
|
||||
func (r *Restorer) downloadAndExtractBasebackup(ctx context.Context, backupID string) error {
|
||||
body, err := r.apiClient.DownloadBackupFile(ctx, backupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = body.Close() }()
|
||||
|
||||
zstdReader, err := zstd.NewReader(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create zstd decompressor: %w", err)
|
||||
}
|
||||
defer zstdReader.Close()
|
||||
|
||||
tarReader := tar.NewReader(zstdReader)
|
||||
|
||||
return r.extractTarArchive(tarReader)
|
||||
}
|
||||
|
||||
func (r *Restorer) extractTarArchive(tarReader *tar.Reader) error {
|
||||
for {
|
||||
header, err := tarReader.Next()
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("read tar entry: %w", err)
|
||||
}
|
||||
|
||||
targetPath := filepath.Join(r.targetPgDataDir, header.Name)
|
||||
|
||||
relativePath, err := filepath.Rel(r.targetPgDataDir, targetPath)
|
||||
if err != nil || strings.HasPrefix(relativePath, "..") {
|
||||
return fmt.Errorf("tar entry attempts path traversal: %s", header.Name)
|
||||
}
|
||||
|
||||
switch header.Typeflag {
|
||||
case tar.TypeDir:
|
||||
if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil {
|
||||
return fmt.Errorf("create directory %s: %w", header.Name, err)
|
||||
}
|
||||
|
||||
case tar.TypeReg:
|
||||
parentDir := filepath.Dir(targetPath)
|
||||
if err := os.MkdirAll(parentDir, 0o755); err != nil {
|
||||
return fmt.Errorf("create parent directory for %s: %w", header.Name, err)
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create file %s: %w", header.Name, err)
|
||||
}
|
||||
|
||||
if _, err := io.Copy(file, tarReader); err != nil {
|
||||
_ = file.Close()
|
||||
return fmt.Errorf("write file %s: %w", header.Name, err)
|
||||
}
|
||||
|
||||
_ = file.Close()
|
||||
|
||||
case tar.TypeSymlink:
|
||||
if err := os.Symlink(header.Linkname, targetPath); err != nil {
|
||||
return fmt.Errorf("create symlink %s: %w", header.Name, err)
|
||||
}
|
||||
|
||||
case tar.TypeLink:
|
||||
linkTarget := filepath.Join(r.targetPgDataDir, header.Linkname)
|
||||
if err := os.Link(linkTarget, targetPath); err != nil {
|
||||
return fmt.Errorf("create hard link %s: %w", header.Name, err)
|
||||
}
|
||||
|
||||
default:
|
||||
r.log.Warn("Skipping unsupported tar entry type",
|
||||
"name", header.Name,
|
||||
"type", header.Typeflag,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Restorer) downloadAllWalSegments(ctx context.Context, segments []api.RestorePlanWalSegment) error {
|
||||
walRestorePath := filepath.Join(r.targetPgDataDir, walRestoreDir)
|
||||
if err := os.MkdirAll(walRestorePath, 0o755); err != nil {
|
||||
return fmt.Errorf("create WAL restore directory: %w", err)
|
||||
}
|
||||
|
||||
for segmentIndex, segment := range segments {
|
||||
if err := r.downloadWalSegmentWithRetry(ctx, segment, segmentIndex, len(segments)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Restorer) downloadWalSegmentWithRetry(
|
||||
ctx context.Context,
|
||||
segment api.RestorePlanWalSegment,
|
||||
segmentIndex int,
|
||||
segmentsTotal int,
|
||||
) error {
|
||||
r.log.Info("Downloading WAL segment",
|
||||
"segment", segment.SegmentName,
|
||||
"progress", fmt.Sprintf("%d/%d", segmentIndex+1, segmentsTotal),
|
||||
)
|
||||
|
||||
var lastErr error
|
||||
|
||||
for attempt := range maxRetryAttempts {
|
||||
if err := r.downloadWalSegment(ctx, segment); err != nil {
|
||||
lastErr = err
|
||||
|
||||
delay := r.getRetryDelay(attempt)
|
||||
r.log.Warn("WAL segment download failed, retrying",
|
||||
"segment", segment.SegmentName,
|
||||
"attempt", attempt+1,
|
||||
"maxAttempts", maxRetryAttempts,
|
||||
"retryDelay", delay,
|
||||
"error", err,
|
||||
)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to download WAL segment %s after %d attempts: %w",
|
||||
segment.SegmentName, maxRetryAttempts, lastErr)
|
||||
}
|
||||
|
||||
func (r *Restorer) downloadWalSegment(ctx context.Context, segment api.RestorePlanWalSegment) error {
|
||||
body, err := r.apiClient.DownloadBackupFile(ctx, segment.BackupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = body.Close() }()
|
||||
|
||||
zstdReader, err := zstd.NewReader(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create zstd decompressor: %w", err)
|
||||
}
|
||||
defer zstdReader.Close()
|
||||
|
||||
segmentPath := filepath.Join(r.targetPgDataDir, walRestoreDir, segment.SegmentName)
|
||||
|
||||
file, err := os.Create(segmentPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create WAL segment file: %w", err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
if _, err := io.Copy(file, zstdReader); err != nil {
|
||||
return fmt.Errorf("write WAL segment: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Restorer) configurePostgresRecovery(parsedTargetTime *time.Time) error {
|
||||
recoverySignalPath := filepath.Join(r.targetPgDataDir, recoverySignalFile)
|
||||
if err := os.WriteFile(recoverySignalPath, []byte{}, 0o644); err != nil {
|
||||
return fmt.Errorf("create recovery.signal: %w", err)
|
||||
}
|
||||
|
||||
walRestoreAbsPath, err := r.resolveWalRestorePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
autoConfPath := filepath.Join(r.targetPgDataDir, autoConfFile)
|
||||
|
||||
autoConfFile, err := os.OpenFile(autoConfPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open postgresql.auto.conf: %w", err)
|
||||
}
|
||||
defer func() { _ = autoConfFile.Close() }()
|
||||
|
||||
var configLines strings.Builder
|
||||
configLines.WriteString("\n# Added by databasus-agent restore\n")
|
||||
fmt.Fprintf(&configLines, "restore_command = 'cp %s/%%f %%p'\n", walRestoreAbsPath)
|
||||
fmt.Fprintf(&configLines, "recovery_end_command = 'rm -rf %s'\n", walRestoreAbsPath)
|
||||
configLines.WriteString("recovery_target_action = 'promote'\n")
|
||||
|
||||
if parsedTargetTime != nil {
|
||||
fmt.Fprintf(&configLines, "recovery_target_time = '%s'\n", parsedTargetTime.Format(time.RFC3339))
|
||||
}
|
||||
|
||||
if _, err := autoConfFile.WriteString(configLines.String()); err != nil {
|
||||
return fmt.Errorf("write to postgresql.auto.conf: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Restorer) printCompletionMessage() {
|
||||
absPgDataDir, _ := filepath.Abs(r.targetPgDataDir)
|
||||
isDocker := r.pgType == "docker"
|
||||
|
||||
fmt.Printf("\nRestore complete. PGDATA directory is ready at %s.\n", absPgDataDir)
|
||||
|
||||
fmt.Print(`
|
||||
What happens when you start PostgreSQL:
|
||||
1. PostgreSQL detects recovery.signal and enters recovery mode
|
||||
2. It replays WAL from the basebackup's consistency point
|
||||
3. It executes restore_command to fetch WAL segments from databasus-wal-restore/
|
||||
4. WAL replay continues until target_time (if PITR) or end of available WAL
|
||||
5. recovery_end_command automatically removes databasus-wal-restore/
|
||||
6. PostgreSQL promotes to primary and removes recovery.signal
|
||||
7. Normal operations resume
|
||||
`)
|
||||
|
||||
if isDocker {
|
||||
fmt.Printf(`
|
||||
Start PostgreSQL by launching a container with the restored data mounted:
|
||||
docker run -d -v %s:%s postgres:<VERSION>
|
||||
|
||||
Or if you have an existing container:
|
||||
docker start <CONTAINER_NAME>
|
||||
|
||||
Ensure %s is mounted as the container's pgdata volume at %s.
|
||||
`, absPgDataDir, dockerContainerPgDataDir, absPgDataDir, dockerContainerPgDataDir)
|
||||
} else {
|
||||
fmt.Printf(`
|
||||
Start PostgreSQL:
|
||||
pg_ctl -D %s start
|
||||
|
||||
Note: If you move the PGDATA directory before starting PostgreSQL,
|
||||
update restore_command and recovery_end_command paths in
|
||||
postgresql.auto.conf accordingly.
|
||||
`, absPgDataDir)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Restorer) resolveWalRestorePath() (string, error) {
|
||||
if r.pgType == "docker" {
|
||||
return dockerContainerPgDataDir + "/" + walRestoreDir, nil
|
||||
}
|
||||
|
||||
absPgDataDir, err := filepath.Abs(r.targetPgDataDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resolve absolute path: %w", err)
|
||||
}
|
||||
|
||||
absPgDataDir = filepath.ToSlash(absPgDataDir)
|
||||
|
||||
return absPgDataDir + "/" + walRestoreDir, nil
|
||||
}
|
||||
|
||||
func (r *Restorer) getRetryDelay(attempt int) time.Duration {
|
||||
if retryDelayOverride != nil {
|
||||
return *retryDelayOverride
|
||||
}
|
||||
|
||||
return retryBaseDelay * time.Duration(1<<attempt)
|
||||
}
|
||||
|
||||
func formatSizeBytes(sizeBytes int64) string {
|
||||
const (
|
||||
kilobyte = 1024
|
||||
megabyte = 1024 * kilobyte
|
||||
gigabyte = 1024 * megabyte
|
||||
)
|
||||
|
||||
switch {
|
||||
case sizeBytes >= gigabyte:
|
||||
return fmt.Sprintf("%.2f GB", float64(sizeBytes)/float64(gigabyte))
|
||||
case sizeBytes >= megabyte:
|
||||
return fmt.Sprintf("%.2f MB", float64(sizeBytes)/float64(megabyte))
|
||||
case sizeBytes >= kilobyte:
|
||||
return fmt.Sprintf("%.2f KB", float64(sizeBytes)/float64(kilobyte))
|
||||
default:
|
||||
return fmt.Sprintf("%d B", sizeBytes)
|
||||
}
|
||||
}
|
||||
711
agent/internal/features/restore/restorer_test.go
Normal file
711
agent/internal/features/restore/restorer_test.go
Normal file
@@ -0,0 +1,711 @@
|
||||
package restore
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-agent/internal/features/api"
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
testRestorePlanPath = "/api/v1/backups/postgres/wal/restore/plan"
|
||||
testRestoreDownloadPath = "/api/v1/backups/postgres/wal/restore/download"
|
||||
|
||||
testFullBackupID = "full-backup-id-1234"
|
||||
testWalSegment1 = "000000010000000100000001"
|
||||
testWalSegment2 = "000000010000000100000002"
|
||||
)
|
||||
|
||||
func Test_RunRestore_WhenBasebackupAndWalSegmentsAvailable_FilesExtractedAndRecoveryConfigured(t *testing.T) {
|
||||
tarFiles := map[string][]byte{
|
||||
"PG_VERSION": []byte("16"),
|
||||
"base/1/somefile": []byte("table-data"),
|
||||
}
|
||||
zstdTarData := createZstdTar(t, tarFiles)
|
||||
walData1 := createZstdData(t, []byte("wal-segment-1-data"))
|
||||
walData2 := createZstdData(t, []byte("wal-segment-2-data"))
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testRestorePlanPath:
|
||||
writeJSON(w, api.GetRestorePlanResponse{
|
||||
FullBackup: api.RestorePlanFullBackup{
|
||||
BackupID: testFullBackupID,
|
||||
FullBackupWalStartSegment: testWalSegment1,
|
||||
FullBackupWalStopSegment: testWalSegment1,
|
||||
PgVersion: "16",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SizeBytes: 1024,
|
||||
},
|
||||
WalSegments: []api.RestorePlanWalSegment{
|
||||
{BackupID: "wal-1", SegmentName: testWalSegment1, SizeBytes: 512},
|
||||
{BackupID: "wal-2", SegmentName: testWalSegment2, SizeBytes: 512},
|
||||
},
|
||||
TotalSizeBytes: 2048,
|
||||
LatestAvailableSegment: testWalSegment2,
|
||||
})
|
||||
|
||||
case testRestoreDownloadPath:
|
||||
backupID := r.URL.Query().Get("backupId")
|
||||
switch backupID {
|
||||
case testFullBackupID:
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(zstdTarData)
|
||||
case "wal-1":
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(walData1)
|
||||
case "wal-2":
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(walData2)
|
||||
default:
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
pgVersionContent, err := os.ReadFile(filepath.Join(targetDir, "PG_VERSION"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "16", string(pgVersionContent))
|
||||
|
||||
someFileContent, err := os.ReadFile(filepath.Join(targetDir, "base", "1", "somefile"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "table-data", string(someFileContent))
|
||||
|
||||
walSegment1Content, err := os.ReadFile(filepath.Join(targetDir, walRestoreDir, testWalSegment1))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "wal-segment-1-data", string(walSegment1Content))
|
||||
|
||||
walSegment2Content, err := os.ReadFile(filepath.Join(targetDir, walRestoreDir, testWalSegment2))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "wal-segment-2-data", string(walSegment2Content))
|
||||
|
||||
recoverySignalPath := filepath.Join(targetDir, "recovery.signal")
|
||||
recoverySignalInfo, err := os.Stat(recoverySignalPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), recoverySignalInfo.Size())
|
||||
|
||||
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
|
||||
require.NoError(t, err)
|
||||
autoConfStr := string(autoConfContent)
|
||||
|
||||
assert.Contains(t, autoConfStr, "restore_command")
|
||||
assert.Contains(t, autoConfStr, walRestoreDir)
|
||||
assert.Contains(t, autoConfStr, "recovery_target_action = 'promote'")
|
||||
assert.Contains(t, autoConfStr, "recovery_end_command")
|
||||
assert.NotContains(t, autoConfStr, "recovery_target_time")
|
||||
}
|
||||
|
||||
func Test_RunRestore_WhenTargetTimeProvided_RecoveryTargetTimeWrittenToConfig(t *testing.T) {
|
||||
tarFiles := map[string][]byte{"PG_VERSION": []byte("16")}
|
||||
zstdTarData := createZstdTar(t, tarFiles)
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testRestorePlanPath:
|
||||
writeJSON(w, api.GetRestorePlanResponse{
|
||||
FullBackup: api.RestorePlanFullBackup{
|
||||
BackupID: testFullBackupID,
|
||||
PgVersion: "16",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SizeBytes: 1024,
|
||||
},
|
||||
WalSegments: []api.RestorePlanWalSegment{},
|
||||
TotalSizeBytes: 1024,
|
||||
LatestAvailableSegment: "",
|
||||
})
|
||||
|
||||
case testRestoreDownloadPath:
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(zstdTarData)
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "2026-02-28T14:30:00Z", "")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, string(autoConfContent), "recovery_target_time = '2026-02-28T14:30:00Z'")
|
||||
}
|
||||
|
||||
func Test_RunRestore_WhenPgDataDirNotEmpty_ReturnsError(t *testing.T) {
|
||||
targetDir := createTestTargetDir(t)
|
||||
|
||||
err := os.WriteFile(filepath.Join(targetDir, "existing-file"), []byte("data"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
restorer := newTestRestorer("http://localhost:0", targetDir, "", "", "")
|
||||
|
||||
err = restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not empty")
|
||||
}
|
||||
|
||||
func Test_RunRestore_WhenPgDataDirDoesNotExist_ReturnsError(t *testing.T) {
|
||||
nonExistentDir := filepath.Join(os.TempDir(), "databasus-test-nonexistent-dir-12345")
|
||||
|
||||
restorer := newTestRestorer("http://localhost:0", nonExistentDir, "", "", "")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "does not exist")
|
||||
}
|
||||
|
||||
func Test_RunRestore_WhenNoBackupsAvailable_ReturnsError(t *testing.T) {
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_ = json.NewEncoder(w).Encode(api.GetRestorePlanErrorResponse{
|
||||
Error: "no_backups",
|
||||
Message: "No full backups available",
|
||||
})
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "No full backups available")
|
||||
}
|
||||
|
||||
func Test_RunRestore_WhenWalChainBroken_ReturnsError(t *testing.T) {
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_ = json.NewEncoder(w).Encode(api.GetRestorePlanErrorResponse{
|
||||
Error: "wal_chain_broken",
|
||||
Message: "WAL chain broken",
|
||||
LastContiguousSegment: testWalSegment1,
|
||||
})
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "WAL chain broken")
|
||||
assert.Contains(t, err.Error(), testWalSegment1)
|
||||
}
|
||||
|
||||
func Test_DownloadWalSegment_WhenFirstAttemptFails_RetriesAndSucceeds(t *testing.T) {
|
||||
tarFiles := map[string][]byte{"PG_VERSION": []byte("16")}
|
||||
zstdTarData := createZstdTar(t, tarFiles)
|
||||
walData := createZstdData(t, []byte("wal-segment-data"))
|
||||
|
||||
var mu sync.Mutex
|
||||
var walDownloadAttempts int
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testRestorePlanPath:
|
||||
writeJSON(w, api.GetRestorePlanResponse{
|
||||
FullBackup: api.RestorePlanFullBackup{
|
||||
BackupID: testFullBackupID,
|
||||
PgVersion: "16",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SizeBytes: 1024,
|
||||
},
|
||||
WalSegments: []api.RestorePlanWalSegment{
|
||||
{BackupID: "wal-1", SegmentName: testWalSegment1, SizeBytes: 512},
|
||||
},
|
||||
TotalSizeBytes: 1536,
|
||||
LatestAvailableSegment: testWalSegment1,
|
||||
})
|
||||
|
||||
case testRestoreDownloadPath:
|
||||
backupID := r.URL.Query().Get("backupId")
|
||||
if backupID == testFullBackupID {
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(zstdTarData)
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
walDownloadAttempts++
|
||||
attempt := walDownloadAttempts
|
||||
mu.Unlock()
|
||||
|
||||
if attempt == 1 {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(`{"error":"storage unavailable"}`))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(walData)
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
origDelay := retryDelayOverride
|
||||
testDelay := 10 * time.Millisecond
|
||||
retryDelayOverride = &testDelay
|
||||
defer func() { retryDelayOverride = origDelay }()
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
mu.Lock()
|
||||
attempts := walDownloadAttempts
|
||||
mu.Unlock()
|
||||
|
||||
assert.Equal(t, 2, attempts)
|
||||
|
||||
walContent, err := os.ReadFile(filepath.Join(targetDir, walRestoreDir, testWalSegment1))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "wal-segment-data", string(walContent))
|
||||
}
|
||||
|
||||
func Test_DownloadWalSegment_WhenAllAttemptsFail_ReturnsErrorWithSegmentName(t *testing.T) {
|
||||
tarFiles := map[string][]byte{"PG_VERSION": []byte("16")}
|
||||
zstdTarData := createZstdTar(t, tarFiles)
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testRestorePlanPath:
|
||||
writeJSON(w, api.GetRestorePlanResponse{
|
||||
FullBackup: api.RestorePlanFullBackup{
|
||||
BackupID: testFullBackupID,
|
||||
PgVersion: "16",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SizeBytes: 1024,
|
||||
},
|
||||
WalSegments: []api.RestorePlanWalSegment{
|
||||
{BackupID: "wal-1", SegmentName: testWalSegment1, SizeBytes: 512},
|
||||
},
|
||||
TotalSizeBytes: 1536,
|
||||
LatestAvailableSegment: testWalSegment1,
|
||||
})
|
||||
|
||||
case testRestoreDownloadPath:
|
||||
backupID := r.URL.Query().Get("backupId")
|
||||
if backupID == testFullBackupID {
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(zstdTarData)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(`{"error":"storage unavailable"}`))
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
origDelay := retryDelayOverride
|
||||
testDelay := 10 * time.Millisecond
|
||||
retryDelayOverride = &testDelay
|
||||
defer func() { retryDelayOverride = origDelay }()
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), testWalSegment1)
|
||||
assert.Contains(t, err.Error(), "3 attempts")
|
||||
}
|
||||
|
||||
func Test_RunRestore_WhenInvalidTargetTimeFormat_ReturnsError(t *testing.T) {
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer("http://localhost:0", targetDir, "", "not-a-valid-time", "")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid --target-time format")
|
||||
}
|
||||
|
||||
func Test_RunRestore_WhenBasebackupDownloadFails_ReturnsError(t *testing.T) {
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testRestorePlanPath:
|
||||
writeJSON(w, api.GetRestorePlanResponse{
|
||||
FullBackup: api.RestorePlanFullBackup{
|
||||
BackupID: testFullBackupID,
|
||||
PgVersion: "16",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SizeBytes: 1024,
|
||||
},
|
||||
WalSegments: []api.RestorePlanWalSegment{},
|
||||
TotalSizeBytes: 1024,
|
||||
LatestAvailableSegment: "",
|
||||
})
|
||||
|
||||
case testRestoreDownloadPath:
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(`{"error":"storage error"}`))
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "basebackup download failed")
|
||||
}
|
||||
|
||||
func Test_RunRestore_WhenNoWalSegmentsInPlan_BasebackupRestoredSuccessfully(t *testing.T) {
|
||||
tarFiles := map[string][]byte{
|
||||
"PG_VERSION": []byte("16"),
|
||||
"global/pg_control": []byte("control-data"),
|
||||
}
|
||||
zstdTarData := createZstdTar(t, tarFiles)
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testRestorePlanPath:
|
||||
writeJSON(w, api.GetRestorePlanResponse{
|
||||
FullBackup: api.RestorePlanFullBackup{
|
||||
BackupID: testFullBackupID,
|
||||
PgVersion: "16",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SizeBytes: 1024,
|
||||
},
|
||||
WalSegments: []api.RestorePlanWalSegment{},
|
||||
TotalSizeBytes: 1024,
|
||||
LatestAvailableSegment: "",
|
||||
})
|
||||
|
||||
case testRestoreDownloadPath:
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(zstdTarData)
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
pgVersionContent, err := os.ReadFile(filepath.Join(targetDir, "PG_VERSION"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "16", string(pgVersionContent))
|
||||
|
||||
walRestoreDirInfo, err := os.Stat(filepath.Join(targetDir, walRestoreDir))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, walRestoreDirInfo.IsDir())
|
||||
|
||||
_, err = os.Stat(filepath.Join(targetDir, "recovery.signal"))
|
||||
require.NoError(t, err)
|
||||
|
||||
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(autoConfContent), "restore_command")
|
||||
}
|
||||
|
||||
func Test_RunRestore_WhenMakingApiCalls_AuthTokenIncludedInRequests(t *testing.T) {
|
||||
tarFiles := map[string][]byte{"PG_VERSION": []byte("16")}
|
||||
zstdTarData := createZstdTar(t, tarFiles)
|
||||
|
||||
var receivedAuthHeaders atomic.Int32
|
||||
var mu sync.Mutex
|
||||
var authHeaderValues []string
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader != "" {
|
||||
receivedAuthHeaders.Add(1)
|
||||
|
||||
mu.Lock()
|
||||
authHeaderValues = append(authHeaderValues, authHeader)
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
switch r.URL.Path {
|
||||
case testRestorePlanPath:
|
||||
writeJSON(w, api.GetRestorePlanResponse{
|
||||
FullBackup: api.RestorePlanFullBackup{
|
||||
BackupID: testFullBackupID,
|
||||
PgVersion: "16",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SizeBytes: 1024,
|
||||
},
|
||||
WalSegments: []api.RestorePlanWalSegment{},
|
||||
TotalSizeBytes: 1024,
|
||||
LatestAvailableSegment: "",
|
||||
})
|
||||
|
||||
case testRestoreDownloadPath:
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(zstdTarData)
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.GreaterOrEqual(t, int(receivedAuthHeaders.Load()), 2)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
for _, headerValue := range authHeaderValues {
|
||||
assert.Equal(t, "test-token", headerValue)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ConfigurePostgresRecovery_WhenPgTypeHost_UsesHostAbsolutePath(t *testing.T) {
|
||||
tarFiles := map[string][]byte{"PG_VERSION": []byte("16")}
|
||||
zstdTarData := createZstdTar(t, tarFiles)
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testRestorePlanPath:
|
||||
writeJSON(w, api.GetRestorePlanResponse{
|
||||
FullBackup: api.RestorePlanFullBackup{
|
||||
BackupID: testFullBackupID,
|
||||
PgVersion: "16",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SizeBytes: 1024,
|
||||
},
|
||||
WalSegments: []api.RestorePlanWalSegment{},
|
||||
TotalSizeBytes: 1024,
|
||||
LatestAvailableSegment: "",
|
||||
})
|
||||
|
||||
case testRestoreDownloadPath:
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(zstdTarData)
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "host")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
|
||||
require.NoError(t, err)
|
||||
autoConfStr := string(autoConfContent)
|
||||
|
||||
absTargetDir, _ := filepath.Abs(targetDir)
|
||||
absTargetDir = filepath.ToSlash(absTargetDir)
|
||||
expectedWalPath := absTargetDir + "/" + walRestoreDir
|
||||
|
||||
assert.Contains(t, autoConfStr, fmt.Sprintf("restore_command = 'cp %s/%%f %%p'", expectedWalPath))
|
||||
assert.Contains(t, autoConfStr, fmt.Sprintf("recovery_end_command = 'rm -rf %s'", expectedWalPath))
|
||||
assert.NotContains(t, autoConfStr, "/var/lib/postgresql/data")
|
||||
}
|
||||
|
||||
func Test_ConfigurePostgresRecovery_WhenPgTypeDocker_UsesContainerPath(t *testing.T) {
|
||||
tarFiles := map[string][]byte{"PG_VERSION": []byte("16")}
|
||||
zstdTarData := createZstdTar(t, tarFiles)
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testRestorePlanPath:
|
||||
writeJSON(w, api.GetRestorePlanResponse{
|
||||
FullBackup: api.RestorePlanFullBackup{
|
||||
BackupID: testFullBackupID,
|
||||
PgVersion: "16",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SizeBytes: 1024,
|
||||
},
|
||||
WalSegments: []api.RestorePlanWalSegment{},
|
||||
TotalSizeBytes: 1024,
|
||||
LatestAvailableSegment: "",
|
||||
})
|
||||
|
||||
case testRestoreDownloadPath:
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(zstdTarData)
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "docker")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
|
||||
require.NoError(t, err)
|
||||
autoConfStr := string(autoConfContent)
|
||||
|
||||
expectedWalPath := "/var/lib/postgresql/data/" + walRestoreDir
|
||||
|
||||
assert.Contains(t, autoConfStr, fmt.Sprintf("restore_command = 'cp %s/%%f %%p'", expectedWalPath))
|
||||
assert.Contains(t, autoConfStr, fmt.Sprintf("recovery_end_command = 'rm -rf %s'", expectedWalPath))
|
||||
|
||||
absTargetDir, _ := filepath.Abs(targetDir)
|
||||
absTargetDir = filepath.ToSlash(absTargetDir)
|
||||
assert.NotContains(t, autoConfStr, absTargetDir)
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
server := httptest.NewServer(handler)
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
func createTestTargetDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
baseDir := filepath.Join(".", ".test-tmp")
|
||||
if err := os.MkdirAll(baseDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create base test dir: %v", err)
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp(baseDir, t.Name()+"-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test target dir: %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(dir)
|
||||
})
|
||||
|
||||
return dir
|
||||
}
|
||||
|
||||
func createZstdTar(t *testing.T, files map[string][]byte) []byte {
|
||||
t.Helper()
|
||||
|
||||
var tarBuffer bytes.Buffer
|
||||
tarWriter := tar.NewWriter(&tarBuffer)
|
||||
|
||||
createdDirs := make(map[string]bool)
|
||||
|
||||
for name, content := range files {
|
||||
dir := filepath.Dir(name)
|
||||
if dir != "." && !createdDirs[dir] {
|
||||
parts := strings.Split(filepath.ToSlash(dir), "/")
|
||||
for partIndex := range parts {
|
||||
partialDir := strings.Join(parts[:partIndex+1], "/")
|
||||
if !createdDirs[partialDir] {
|
||||
err := tarWriter.WriteHeader(&tar.Header{
|
||||
Name: partialDir + "/",
|
||||
Typeflag: tar.TypeDir,
|
||||
Mode: 0o755,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
createdDirs[partialDir] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err := tarWriter.WriteHeader(&tar.Header{
|
||||
Name: name,
|
||||
Size: int64(len(content)),
|
||||
Mode: 0o644,
|
||||
Typeflag: tar.TypeReg,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tarWriter.Write(content)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.NoError(t, tarWriter.Close())
|
||||
|
||||
var zstdBuffer bytes.Buffer
|
||||
|
||||
encoder, err := zstd.NewWriter(&zstdBuffer,
|
||||
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(5)),
|
||||
zstd.WithEncoderCRC(true),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = encoder.Write(tarBuffer.Bytes())
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, encoder.Close())
|
||||
|
||||
return zstdBuffer.Bytes()
|
||||
}
|
||||
|
||||
func createZstdData(t *testing.T, data []byte) []byte {
|
||||
t.Helper()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
|
||||
encoder, err := zstd.NewWriter(&buffer,
|
||||
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(5)),
|
||||
zstd.WithEncoderCRC(true),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = encoder.Write(data)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, encoder.Close())
|
||||
|
||||
return buffer.Bytes()
|
||||
}
|
||||
|
||||
func newTestRestorer(serverURL, targetPgDataDir, backupID, targetTime, pgType string) *Restorer {
|
||||
apiClient := api.NewClient(serverURL, "test-token", logger.GetLogger())
|
||||
|
||||
return NewRestorer(apiClient, logger.GetLogger(), targetPgDataDir, backupID, targetTime, pgType)
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, value any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if err := json.NewEncoder(w).Encode(value); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
@@ -50,8 +50,23 @@ func AcquireLock(log *slog.Logger) (*os.File, error) {
|
||||
|
||||
func ReleaseLock(f *os.File) {
|
||||
_ = syscall.Flock(int(f.Fd()), syscall.LOCK_UN)
|
||||
|
||||
lockedStat, lockedErr := f.Stat()
|
||||
_ = f.Close()
|
||||
_ = os.Remove(lockFileName)
|
||||
|
||||
if lockedErr != nil {
|
||||
_ = os.Remove(lockFileName)
|
||||
return
|
||||
}
|
||||
|
||||
diskStat, diskErr := os.Stat(lockFileName)
|
||||
if diskErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if os.SameFile(lockedStat, diskStat) {
|
||||
_ = os.Remove(lockFileName)
|
||||
}
|
||||
}
|
||||
|
||||
func ReadLockFilePID() (int, error) {
|
||||
|
||||
@@ -21,7 +21,7 @@ func Test_NewLockWatcher_CapturesInode(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
_, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
@@ -37,7 +37,7 @@ func Test_LockWatcher_FileUnchanged_ContextNotCancelled(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
@@ -62,7 +62,7 @@ func Test_LockWatcher_FileDeleted_CancelsContext(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
@@ -88,7 +88,7 @@ func Test_LockWatcher_FileReplacedWithDifferentInode_CancelsContext(t *testing.T
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
|
||||
@@ -205,6 +205,15 @@ func verifyPgBasebackupDocker(cfg *config.Config, log *slog.Logger) error {
|
||||
}
|
||||
|
||||
func verifyDatabase(cfg *config.Config, log *slog.Logger) error {
|
||||
switch cfg.PgType {
|
||||
case "docker":
|
||||
return verifyDatabaseDocker(cfg, log)
|
||||
default:
|
||||
return verifyDatabaseHost(cfg, log)
|
||||
}
|
||||
}
|
||||
|
||||
func verifyDatabaseHost(cfg *config.Config, log *slog.Logger) error {
|
||||
connStr := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=postgres sslmode=disable",
|
||||
cfg.PgHost, cfg.PgPort, cfg.PgUser, cfg.PgPassword,
|
||||
@@ -255,6 +264,51 @@ func verifyDatabase(cfg *config.Config, log *slog.Logger) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyDatabaseDocker(cfg *config.Config, log *slog.Logger) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dbVerifyTimeout)
|
||||
defer cancel()
|
||||
|
||||
query := "SELECT current_setting('server_version_num')"
|
||||
|
||||
cmd := exec.CommandContext(ctx,
|
||||
"docker", "exec",
|
||||
"-e", "PGPASSWORD="+cfg.PgPassword,
|
||||
cfg.PgDockerContainerName,
|
||||
"psql", "-h", "localhost", "-p", "5432", "-U", cfg.PgUser,
|
||||
"-d", "postgres", "-t", "-A", "-c", query,
|
||||
)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to connect to PostgreSQL in container '%s' as user '%s': %w (output: %s)",
|
||||
cfg.PgDockerContainerName, cfg.PgUser, err, strings.TrimSpace(string(output)),
|
||||
)
|
||||
}
|
||||
|
||||
versionNumStr := strings.TrimSpace(string(output))
|
||||
|
||||
majorVersion, err := parsePgVersionNum(versionNumStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse PostgreSQL version '%s': %w", versionNumStr, err)
|
||||
}
|
||||
|
||||
if majorVersion < minPgMajorVersion {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL %d is not supported, minimum required version is %d",
|
||||
majorVersion, minPgMajorVersion,
|
||||
)
|
||||
}
|
||||
|
||||
log.Info("PostgreSQL connection verified (docker)",
|
||||
"container", cfg.PgDockerContainerName,
|
||||
"user", cfg.PgUser,
|
||||
"version", majorVersion,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parsePgVersionNum(versionNumStr string) (int, error) {
|
||||
versionNum, err := strconv.Atoi(strings.TrimSpace(versionNumStr))
|
||||
if err != nil {
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"databasus-agent/internal/features/api"
|
||||
)
|
||||
|
||||
const backgroundCheckInterval = 5 * time.Second
|
||||
const backgroundCheckInterval = 10 * time.Second
|
||||
|
||||
type BackgroundUpgrader struct {
|
||||
apiClient *api.Client
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -18,8 +18,10 @@ import (
|
||||
"databasus-agent/internal/features/api"
|
||||
)
|
||||
|
||||
var uploadIdleTimeout = 5 * time.Minute
|
||||
|
||||
const (
|
||||
pollInterval = 2 * time.Second
|
||||
pollInterval = 10 * time.Second
|
||||
uploadTimeout = 5 * time.Minute
|
||||
)
|
||||
|
||||
@@ -65,6 +67,13 @@ func (s *Streamer) processQueue(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(segments) == 0 {
|
||||
s.log.Info("No WAL segments pending", "dir", s.cfg.PgWalDir)
|
||||
return
|
||||
}
|
||||
|
||||
s.log.Info("WAL segments pending upload", "dir", s.cfg.PgWalDir, "count", len(segments))
|
||||
|
||||
for _, segmentName := range segments {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
@@ -106,7 +115,7 @@ func (s *Streamer) listSegments() ([]string, error) {
|
||||
segments = append(segments, name)
|
||||
}
|
||||
|
||||
sort.Strings(segments)
|
||||
slices.Sort(segments)
|
||||
|
||||
return segments, nil
|
||||
}
|
||||
@@ -115,14 +124,27 @@ func (s *Streamer) uploadSegment(ctx context.Context, segmentName string) error
|
||||
filePath := filepath.Join(s.cfg.PgWalDir, segmentName)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
defer func() { _ = pr.Close() }()
|
||||
|
||||
go s.compressAndStream(pw, filePath)
|
||||
|
||||
uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout)
|
||||
defer cancel()
|
||||
uploadCtx, timeoutCancel := context.WithTimeout(ctx, uploadTimeout)
|
||||
defer timeoutCancel()
|
||||
|
||||
result, err := s.apiClient.UploadWalSegment(uploadCtx, segmentName, pr)
|
||||
idleCtx, idleCancel := context.WithCancelCause(uploadCtx)
|
||||
defer idleCancel(nil)
|
||||
|
||||
idleReader := api.NewIdleTimeoutReader(pr, uploadIdleTimeout, idleCancel)
|
||||
defer idleReader.Stop()
|
||||
|
||||
s.log.Info("Uploading WAL segment", "segment", segmentName)
|
||||
|
||||
result, err := s.apiClient.UploadWalSegment(idleCtx, segmentName, idleReader)
|
||||
if err != nil {
|
||||
if cause := context.Cause(idleCtx); cause != nil {
|
||||
return fmt.Errorf("upload WAL segment: %w", cause)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -136,7 +158,7 @@ func (s *Streamer) uploadSegment(ctx context.Context, segmentName string) error
|
||||
return fmt.Errorf("gap detected for segment %s", segmentName)
|
||||
}
|
||||
|
||||
s.log.Debug("WAL segment uploaded", "segment", segmentName)
|
||||
s.log.Info("WAL segment uploaded", "segment", segmentName)
|
||||
|
||||
if *s.cfg.IsDeleteWalAfterUpload {
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
|
||||
@@ -2,6 +2,7 @@ package wal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -42,7 +44,7 @@ func Test_UploadSegment_SingleSegment_ServerReceivesCorrectHeadersAndBody(t *tes
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
@@ -79,7 +81,7 @@ func Test_UploadSegments_MultipleSegmentsOutOfOrder_UploadedInAscendingOrder(t *
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
@@ -115,7 +117,7 @@ func Test_UploadSegments_DirectoryHasTmpFiles_TmpFilesIgnored(t *testing.T) {
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
@@ -146,7 +148,7 @@ func Test_UploadSegment_DeleteEnabled_FileRemovedAfterUpload(t *testing.T) {
|
||||
apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger())
|
||||
streamer := NewStreamer(cfg, apiClient, logger.GetLogger())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
@@ -174,7 +176,7 @@ func Test_UploadSegment_DeleteDisabled_FileKeptAfterUpload(t *testing.T) {
|
||||
apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger())
|
||||
streamer := NewStreamer(cfg, apiClient, logger.GetLogger())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
@@ -199,7 +201,7 @@ func Test_UploadSegment_ServerReturns500_FileKeptInQueue(t *testing.T) {
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
@@ -223,7 +225,7 @@ func Test_ProcessQueue_EmptyDirectory_NoUploads(t *testing.T) {
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
@@ -238,7 +240,7 @@ func Test_Run_ContextCancelled_StopsImmediately(t *testing.T) {
|
||||
|
||||
streamer := newTestStreamer(walDir, "http://localhost:0")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
@@ -276,7 +278,7 @@ func Test_UploadSegment_ServerReturns409_FileNotDeleted(t *testing.T) {
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
@@ -287,6 +289,49 @@ func Test_UploadSegment_ServerReturns409_FileNotDeleted(t *testing.T) {
|
||||
assert.NoError(t, err, "segment file should not be deleted on gap detection")
|
||||
}
|
||||
|
||||
func Test_UploadSegment_WhenUploadStalls_FailsWithIdleTimeout(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
|
||||
// Use incompressible random data to ensure TCP buffers fill up
|
||||
segmentContent := make([]byte, 1024*1024)
|
||||
_, err := rand.Read(segmentContent)
|
||||
require.NoError(t, err)
|
||||
|
||||
writeTestSegment(t, walDir, "000000010000000100000001", segmentContent)
|
||||
|
||||
var requestReceived atomic.Bool
|
||||
handlerDone := make(chan struct{})
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestReceived.Store(true)
|
||||
|
||||
// Read one byte then stall — simulates a network stall
|
||||
buf := make([]byte, 1)
|
||||
_, _ = r.Body.Read(buf)
|
||||
<-handlerDone
|
||||
}))
|
||||
defer server.Close()
|
||||
defer close(handlerDone)
|
||||
|
||||
origIdleTimeout := uploadIdleTimeout
|
||||
uploadIdleTimeout = 200 * time.Millisecond
|
||||
defer func() { uploadIdleTimeout = origIdleTimeout }()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
uploadErr := streamer.uploadSegment(ctx, "000000010000000100000001")
|
||||
|
||||
assert.Error(t, uploadErr, "upload should fail when stalled")
|
||||
assert.True(t, requestReceived.Load(), "server should have received the request")
|
||||
assert.Contains(t, uploadErr.Error(), "idle timeout", "error should mention idle timeout")
|
||||
|
||||
_, statErr := os.Stat(filepath.Join(walDir, "000000010000000100000001"))
|
||||
assert.NoError(t, statErr, "segment file should remain in queue after idle timeout")
|
||||
}
|
||||
|
||||
func newTestStreamer(walDir, serverURL string) *Streamer {
|
||||
cfg := createTestConfig(walDir, serverURL)
|
||||
apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger())
|
||||
|
||||
@@ -64,16 +64,12 @@ func (w *rotatingWriter) rotate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
loggerInstance *slog.Logger
|
||||
once sync.Once
|
||||
)
|
||||
var loggerInstance *slog.Logger
|
||||
|
||||
var initLogger = sync.OnceFunc(initialize)
|
||||
|
||||
func GetLogger() *slog.Logger {
|
||||
once.Do(func() {
|
||||
initialize()
|
||||
})
|
||||
|
||||
initLogger()
|
||||
return loggerInstance
|
||||
}
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ func Test_Write_MultipleSmallWrites_CurrentSizeAccumulated(t *testing.T) {
|
||||
rw, _, _ := setupRotatingWriter(t, 1024)
|
||||
|
||||
var totalWritten int64
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
data := []byte("line\n")
|
||||
n, err := rw.Write(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -27,6 +27,13 @@ VALKEY_PORT=6379
|
||||
VALKEY_USERNAME=
|
||||
VALKEY_PASSWORD=
|
||||
VALKEY_IS_SSL=false
|
||||
# billing
|
||||
PRICE_PER_GB_CENTS=
|
||||
IS_PADDLE_SANDBOX=true
|
||||
PADDLE_API_KEY=
|
||||
PADDLE_WEBHOOK_SECRET=
|
||||
PADDLE_PRICE_ID=
|
||||
PADDLE_CLIENT_TOKEN=
|
||||
# testing
|
||||
# to get Google Drive env variables: add storage in UI and copy data from added storage here
|
||||
TEST_GOOGLE_DRIVE_CLIENT_ID=
|
||||
@@ -45,9 +52,6 @@ TEST_MINIO_PORT=9000
|
||||
TEST_MINIO_CONSOLE_PORT=9001
|
||||
# testing NAS
|
||||
TEST_NAS_PORT=7006
|
||||
# testing Telegram
|
||||
TEST_TELEGRAM_BOT_TOKEN=
|
||||
TEST_TELEGRAM_CHAT_ID=
|
||||
# testing Azure Blob Storage
|
||||
TEST_AZURITE_BLOB_PORT=10000
|
||||
# supabase
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime/debug"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -25,6 +26,8 @@ import (
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/billing"
|
||||
billing_paddle "databasus-backend/internal/features/billing/paddle"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/disk"
|
||||
"databasus-backend/internal/features/encryption/secrets"
|
||||
@@ -105,7 +108,9 @@ func main() {
|
||||
go generateSwaggerDocs(log)
|
||||
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
ginApp := gin.Default()
|
||||
ginApp := gin.New()
|
||||
ginApp.Use(gin.Logger())
|
||||
ginApp.Use(ginRecoveryWithLogger(log))
|
||||
|
||||
// Add GZIP compression middleware
|
||||
ginApp.Use(gzip.Gzip(
|
||||
@@ -188,7 +193,7 @@ func startServerWithGracefulShutdown(log *slog.Logger, app *gin.Engine) {
|
||||
log.Info("Shutdown signal received")
|
||||
|
||||
// Gracefully shutdown VictoriaLogs writer
|
||||
logger.ShutdownVictoriaLogs(5 * time.Second)
|
||||
logger.ShutdownVictoriaLogs()
|
||||
|
||||
// The context is used to inform the server it has 10 seconds to finish
|
||||
// the request it is currently handling
|
||||
@@ -217,6 +222,10 @@ func setUpRoutes(r *gin.Engine) {
|
||||
backups_controllers.GetPostgresWalBackupController().RegisterRoutes(v1)
|
||||
databases.GetDatabaseController().RegisterPublicRoutes(v1)
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
billing_paddle.GetPaddleBillingController().RegisterPublicRoutes(v1)
|
||||
}
|
||||
|
||||
// Setup auth middleware
|
||||
userService := users_services.GetUserService()
|
||||
authMiddleware := users_middleware.AuthMiddleware(userService)
|
||||
@@ -240,6 +249,7 @@ func setUpRoutes(r *gin.Engine) {
|
||||
audit_logs.GetAuditLogController().RegisterRoutes(protected)
|
||||
users_controllers.GetManagementController().RegisterRoutes(protected)
|
||||
users_controllers.GetSettingsController().RegisterRoutes(protected)
|
||||
billing.GetBillingController().RegisterRoutes(protected)
|
||||
}
|
||||
|
||||
func setUpDependencies() {
|
||||
@@ -252,6 +262,11 @@ func setUpDependencies() {
|
||||
storages.SetupDependencies()
|
||||
backups_config.SetupDependencies()
|
||||
task_cancellation.SetupDependencies()
|
||||
billing.SetupDependencies()
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
billing_paddle.SetupDependencies()
|
||||
}
|
||||
}
|
||||
|
||||
func runBackgroundTasks(log *slog.Logger) {
|
||||
@@ -308,6 +323,12 @@ func runBackgroundTasks(log *slog.Logger) {
|
||||
go runWithPanicLogging(log, "restore nodes registry background service", func() {
|
||||
restoring.GetRestoreNodesRegistry().Run(ctx)
|
||||
})
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
go runWithPanicLogging(log, "billing background service", func() {
|
||||
billing.GetBillingService().Run(ctx, *log)
|
||||
})
|
||||
}
|
||||
} else {
|
||||
log.Info("Skipping primary node tasks as not primary node")
|
||||
}
|
||||
@@ -330,7 +351,7 @@ func runBackgroundTasks(log *slog.Logger) {
|
||||
func runWithPanicLogging(log *slog.Logger, serviceName string, fn func()) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error("Panic in "+serviceName, "error", r)
|
||||
log.Error("Panic in "+serviceName, "error", r, "stacktrace", string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
fn()
|
||||
@@ -410,6 +431,25 @@ func enableCors(ginApp *gin.Engine) {
|
||||
}
|
||||
}
|
||||
|
||||
func ginRecoveryWithLogger(log *slog.Logger) gin.HandlerFunc {
|
||||
return func(ctx *gin.Context) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error("Panic recovered in HTTP handler",
|
||||
"error", r,
|
||||
"stacktrace", string(debug.Stack()),
|
||||
"method", ctx.Request.Method,
|
||||
"path", ctx.Request.URL.Path,
|
||||
)
|
||||
|
||||
ctx.AbortWithStatus(http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func mountFrontend(ginApp *gin.Engine) {
|
||||
staticDir := "./ui/build"
|
||||
ginApp.NoRoute(func(c *gin.Context) {
|
||||
|
||||
@@ -5,6 +5,7 @@ go 1.26.1
|
||||
require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
|
||||
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3
|
||||
github.com/PaddleHQ/paddle-go-sdk v1.0.0
|
||||
github.com/gin-contrib/cors v1.7.5
|
||||
github.com/gin-contrib/gzip v1.2.3
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
@@ -100,6 +101,8 @@ require (
|
||||
github.com/emersion/go-message v0.18.2 // indirect
|
||||
github.com/emersion/go-vcard v0.0.0-20241024213814-c9703dde27ff // indirect
|
||||
github.com/flynn/noise v1.1.0 // indirect
|
||||
github.com/ggicci/httpin v0.19.0 // indirect
|
||||
github.com/ggicci/owl v0.8.2 // indirect
|
||||
github.com/go-chi/chi/v5 v5.2.3 // indirect
|
||||
github.com/go-darwin/apfs v0.0.0-20211011131704-f84b94dbf348 // indirect
|
||||
github.com/go-git/go-billy/v5 v5.6.2 // indirect
|
||||
|
||||
@@ -77,6 +77,8 @@ github.com/Max-Sum/base32768 v0.0.0-20230304063302-18e6ce5945fd/go.mod h1:C8yoIf
|
||||
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/PaddleHQ/paddle-go-sdk v1.0.0 h1:+EXitsPFbRcc0CpQE/MIeudxiVOR8pFe/aOWTEUHDKU=
|
||||
github.com/PaddleHQ/paddle-go-sdk v1.0.0/go.mod h1:kbBBzf0BHEj38QvhtoELqlGip3alKgA/I+vl7RQzB58=
|
||||
github.com/ProtonMail/bcrypt v0.0.0-20210511135022-227b4adcab57/go.mod h1:HecWFHognK8GfRDGnFQbW/LiV7A3MX3gZVs45vk5h8I=
|
||||
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs69zUkSzubzjBbL+cmOXgnmt9Fyd9ug=
|
||||
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo=
|
||||
@@ -248,6 +250,10 @@ github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9t
|
||||
github.com/geoffgarside/ber v1.1.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc=
|
||||
github.com/geoffgarside/ber v1.2.0 h1:/loowoRcs/MWLYmGX9QtIAbA+V/FrnVLsMMPhwiRm64=
|
||||
github.com/geoffgarside/ber v1.2.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc=
|
||||
github.com/ggicci/httpin v0.19.0 h1:p0B3SWLVgg770VirYiHB14M5wdRx3zR8mCTzM/TkTQ8=
|
||||
github.com/ggicci/httpin v0.19.0/go.mod h1:hzsQHcbqLabmGOycf7WNw6AAzcVbsMeoOp46bWAbIWc=
|
||||
github.com/ggicci/owl v0.8.2 h1:og+lhqpzSMPDdEB+NJfzoAJARP7qCG3f8uUC3xvGukA=
|
||||
github.com/ggicci/owl v0.8.2/go.mod h1:PHRD57u41vFN5UtFz2SF79yTVoM3HlWpjMiE+ZU2dj4=
|
||||
github.com/gin-contrib/cors v1.7.5 h1:cXC9SmofOrRg0w9PigwGlHG3ztswH6bqq4vJVXnvYMk=
|
||||
github.com/gin-contrib/cors v1.7.5/go.mod h1:4q3yi7xBEDDWKapjT2o1V7mScKDDr8k+jZ0fSquGoy0=
|
||||
github.com/gin-contrib/gzip v1.2.3 h1:dAhT722RuEG330ce2agAs75z7yB+NKvX/ZM1r8w0u2U=
|
||||
@@ -454,6 +460,8 @@ github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1
|
||||
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
|
||||
github.com/jtolio/noiseconn v0.0.0-20231127013910-f6d9ecbf1de7 h1:JcltaO1HXM5S2KYOYcKgAV7slU0xPy1OcvrVgn98sRQ=
|
||||
github.com/jtolio/noiseconn v0.0.0-20231127013910-f6d9ecbf1de7/go.mod h1:MEkhEPFwP3yudWO0lj6vfYpLIB+3eIcuIW+e0AZzUQk=
|
||||
github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo=
|
||||
github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA=
|
||||
github.com/jzelinskie/whirlpool v0.0.0-20201016144138-0675e54bb004 h1:G+9t9cEtnC9jFiTxyptEKuNIAbiN5ZCQzX2a74lj3xg=
|
||||
github.com/jzelinskie/whirlpool v0.0.0-20201016144138-0675e54bb004/go.mod h1:KmHnJWQrgEvbuy0vcvj00gtMqbvNn1L+3YUZLK/B92c=
|
||||
github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU=
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ilyakaznacheev/cleanenv"
|
||||
"github.com/joho/godotenv"
|
||||
@@ -53,6 +54,20 @@ type EnvVariables struct {
|
||||
TempFolder string
|
||||
SecretKeyPath string
|
||||
|
||||
// Billing (always tax-exclusive)
|
||||
PricePerGBCents int64 `env:"PRICE_PER_GB_CENTS"`
|
||||
MinStorageGB int
|
||||
MaxStorageGB int
|
||||
TrialDuration time.Duration
|
||||
TrialStorageGB int
|
||||
GracePeriod time.Duration
|
||||
// Paddle billing
|
||||
IsPaddleSandbox bool `env:"IS_PADDLE_SANDBOX"`
|
||||
PaddleApiKey string `env:"PADDLE_API_KEY"`
|
||||
PaddleWebhookSecret string `env:"PADDLE_WEBHOOK_SECRET"`
|
||||
PaddlePriceID string `env:"PADDLE_PRICE_ID"`
|
||||
PaddleClientToken string `env:"PADDLE_CLIENT_TOKEN"`
|
||||
|
||||
TestGoogleDriveClientID string `env:"TEST_GOOGLE_DRIVE_CLIENT_ID"`
|
||||
TestGoogleDriveClientSecret string `env:"TEST_GOOGLE_DRIVE_CLIENT_SECRET"`
|
||||
TestGoogleDriveTokenJSON string `env:"TEST_GOOGLE_DRIVE_TOKEN_JSON"`
|
||||
@@ -109,10 +124,6 @@ type EnvVariables struct {
|
||||
CloudflareTurnstileSecretKey string `env:"CLOUDFLARE_TURNSTILE_SECRET_KEY"`
|
||||
CloudflareTurnstileSiteKey string `env:"CLOUDFLARE_TURNSTILE_SITE_KEY"`
|
||||
|
||||
// testing Telegram
|
||||
TestTelegramBotToken string `env:"TEST_TELEGRAM_BOT_TOKEN"`
|
||||
TestTelegramChatID string `env:"TEST_TELEGRAM_CHAT_ID"`
|
||||
|
||||
// testing Supabase
|
||||
TestSupabaseHost string `env:"TEST_SUPABASE_HOST"`
|
||||
TestSupabasePort string `env:"TEST_SUPABASE_PORT"`
|
||||
@@ -131,14 +142,13 @@ type EnvVariables struct {
|
||||
DatabasusURL string `env:"DATABASUS_URL"`
|
||||
}
|
||||
|
||||
var (
|
||||
env EnvVariables
|
||||
once sync.Once
|
||||
)
|
||||
var env EnvVariables
|
||||
|
||||
func GetEnv() EnvVariables {
|
||||
once.Do(loadEnvVariables)
|
||||
return env
|
||||
var initEnv = sync.OnceFunc(loadEnvVariables)
|
||||
|
||||
func GetEnv() *EnvVariables {
|
||||
initEnv()
|
||||
return &env
|
||||
}
|
||||
|
||||
func loadEnvVariables() {
|
||||
@@ -365,16 +375,41 @@ func loadEnvVariables() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.TestTelegramBotToken == "" {
|
||||
log.Error("TEST_TELEGRAM_BOT_TOKEN is empty")
|
||||
}
|
||||
|
||||
// Billing
|
||||
if env.IsCloud {
|
||||
if env.PricePerGBCents == 0 {
|
||||
log.Error("PRICE_PER_GB_CENTS is empty or zero")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.TestTelegramChatID == "" {
|
||||
log.Error("TEST_TELEGRAM_CHAT_ID is empty")
|
||||
if env.PaddleApiKey == "" {
|
||||
log.Error("PADDLE_API_KEY is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.PaddleWebhookSecret == "" {
|
||||
log.Error("PADDLE_WEBHOOK_SECRET is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.PaddlePriceID == "" {
|
||||
log.Error("PADDLE_PRICE_ID is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.PaddleClientToken == "" {
|
||||
log.Error("PADDLE_CLIENT_TOKEN is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
env.MinStorageGB = 20
|
||||
env.MaxStorageGB = 10_000
|
||||
env.TrialDuration = 24 * time.Hour
|
||||
env.TrialStorageGB = 20
|
||||
env.GracePeriod = 30 * 24 * time.Hour
|
||||
|
||||
log.Info("Environment variables loaded successfully!")
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
@@ -13,39 +12,32 @@ type AuditLogBackgroundService struct {
|
||||
auditLogService *AuditLogService
|
||||
logger *slog.Logger
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
if s.hasRun.Swap(true) {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
s.logger.Info("Starting audit log cleanup background service")
|
||||
|
||||
s.logger.Info("Starting audit log cleanup background service")
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -102,7 +102,7 @@ func Test_CleanOldAuditLogs_DeletesMultipleOldLogs(t *testing.T) {
|
||||
|
||||
// Create many old logs with specific UUIDs to track them
|
||||
testLogIDs := make([]uuid.UUID, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
for i := range 5 {
|
||||
testLogIDs[i] = uuid.New()
|
||||
daysAgo := 400 + (i * 10)
|
||||
log := &AuditLog{
|
||||
|
||||
@@ -2,7 +2,6 @@ package audit_logs
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
@@ -23,8 +22,6 @@ var auditLogController = &AuditLogController{
|
||||
var auditLogBackgroundService = &AuditLogBackgroundService{
|
||||
auditLogService: auditLogService,
|
||||
logger: logger.GetLogger(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
func GetAuditLogService() *AuditLogService {
|
||||
@@ -39,23 +36,8 @@ func GetAuditLogBackgroundService() *AuditLogBackgroundService {
|
||||
return auditLogBackgroundService
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
var SetupDependencies = sync.OnceFunc(func() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
})
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -46,80 +45,73 @@ type BackuperNode struct {
|
||||
|
||||
lastHeartbeat time.Time
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (n *BackuperNode) Run(ctx context.Context) {
|
||||
wasAlreadyRun := n.hasRun.Load()
|
||||
if n.hasRun.Swap(true) {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", n))
|
||||
}
|
||||
|
||||
n.runOnce.Do(func() {
|
||||
n.hasRun.Store(true)
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
|
||||
|
||||
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
|
||||
backupNode := BackupNode{
|
||||
ID: n.nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: time.Now().UTC(),
|
||||
}
|
||||
|
||||
backupNode := BackupNode{
|
||||
ID: n.nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: time.Now().UTC(),
|
||||
}
|
||||
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
|
||||
n.logger.Error("Failed to register node in registry", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
|
||||
n.logger.Error("Failed to register node in registry", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
|
||||
go func() {
|
||||
n.MakeBackup(backupID, isCallNotifier)
|
||||
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
|
||||
n.logger.Error(
|
||||
"Failed to publish backup completion",
|
||||
"error",
|
||||
err,
|
||||
"backupID",
|
||||
backupID,
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
|
||||
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
|
||||
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
|
||||
go func() {
|
||||
n.MakeBackup(backupID, isCallNotifier)
|
||||
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
|
||||
n.logger.Error(
|
||||
"Failed to publish backup completion",
|
||||
"error",
|
||||
err,
|
||||
"backupID",
|
||||
backupID,
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(heartbeatTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
|
||||
|
||||
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
|
||||
n.logger.Error("Failed to unregister node from registry", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
n.sendHeartbeat(&backupNode)
|
||||
}
|
||||
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
|
||||
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
|
||||
}
|
||||
})
|
||||
}()
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", n))
|
||||
ticker := time.NewTicker(heartbeatTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
|
||||
|
||||
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
|
||||
n.logger.Error("Failed to unregister node from registry", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
n.sendHeartbeat(&backupNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,26 +163,6 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
backup.BackupSizeMb = completedMBs
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Check size limit (0 = unlimited)
|
||||
if backupConfig.MaxBackupSizeMB > 0 &&
|
||||
completedMBs > float64(backupConfig.MaxBackupSizeMB) {
|
||||
errMsg := fmt.Sprintf(
|
||||
"backup size (%.2f MB) exceeded maximum allowed size (%d MB)",
|
||||
completedMBs,
|
||||
backupConfig.MaxBackupSizeMB,
|
||||
)
|
||||
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.IsSkipRetry = true
|
||||
backup.FailMessage = &errMsg
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup with size exceeded error", "error", err)
|
||||
}
|
||||
cancel() // Cancel the backup context
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to update backup progress", "error", err)
|
||||
}
|
||||
@@ -308,7 +280,6 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
return
|
||||
}
|
||||
|
||||
backup.Status = backups_core.BackupStatusCompleted
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Update backup with encryption metadata if provided
|
||||
@@ -325,12 +296,6 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
backup.Encryption = backupMetadata.Encryption
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Save metadata file to storage
|
||||
if backupMetadata != nil {
|
||||
metadataJSON, err := json.Marshal(backupMetadata)
|
||||
if err != nil {
|
||||
@@ -363,6 +328,13 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
}
|
||||
}
|
||||
|
||||
backup.Status = backups_core.BackupStatusCompleted
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update database last backup time
|
||||
now := time.Now().UTC()
|
||||
if updateErr := n.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
|
||||
|
||||
@@ -153,121 +153,3 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
assert.Equal(t, notifier.ID, capturedNotifier.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_BackupSizeLimits(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
t.Run("UnlimitedSize_MaxBackupSizeMBIsZero_BackupCompletes", func(t *testing.T) {
|
||||
// Enable backups with unlimited size (0)
|
||||
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backupConfig.MaxBackupSizeMB = 0 // unlimited
|
||||
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.createBackupUseCase = &CreateLargeBackupUsecase{}
|
||||
|
||||
// Create a backup record
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, false)
|
||||
|
||||
// Verify backup completed successfully even with large size
|
||||
updatedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
|
||||
assert.Equal(t, float64(10000), updatedBackup.BackupSizeMb)
|
||||
assert.Nil(t, updatedBackup.FailMessage)
|
||||
})
|
||||
|
||||
t.Run("SizeExceeded_BackupFailedWithIsSkipRetry", func(t *testing.T) {
|
||||
// Enable backups with 5 MB limit
|
||||
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backupConfig.MaxBackupSizeMB = 5
|
||||
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.createBackupUseCase = &CreateProgressiveBackupUsecase{}
|
||||
|
||||
// Create a backup record
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, false)
|
||||
|
||||
// Verify backup was marked as failed with IsSkipRetry=true
|
||||
updatedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusFailed, updatedBackup.Status)
|
||||
assert.True(t, updatedBackup.IsSkipRetry)
|
||||
assert.NotNil(t, updatedBackup.FailMessage)
|
||||
assert.Contains(t, *updatedBackup.FailMessage, "exceeded maximum allowed size")
|
||||
assert.Contains(t, *updatedBackup.FailMessage, "10.00 MB")
|
||||
assert.Contains(t, *updatedBackup.FailMessage, "5 MB")
|
||||
assert.Greater(t, updatedBackup.BackupSizeMb, float64(5))
|
||||
})
|
||||
|
||||
t.Run("SizeWithinLimit_BackupCompletes", func(t *testing.T) {
|
||||
// Enable backups with 100 MB limit
|
||||
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backupConfig.MaxBackupSizeMB = 100
|
||||
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.createBackupUseCase = &CreateMediumBackupUsecase{}
|
||||
|
||||
// Create a backup record
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, false)
|
||||
|
||||
// Verify backup completed successfully
|
||||
updatedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
|
||||
assert.Equal(t, float64(50), updatedBackup.BackupSizeMb)
|
||||
assert.Nil(t, updatedBackup.FailMessage)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,12 +4,12 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/storages"
|
||||
@@ -26,49 +26,47 @@ type BackupCleaner struct {
|
||||
backupRepository *backups_core.BackupRepository
|
||||
storageService *storages.StorageService
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
billingService BillingService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
logger *slog.Logger
|
||||
backupRemoveListeners []backups_core.BackupRemoveListener
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) Run(ctx context.Context) {
|
||||
wasAlreadyRun := c.hasRun.Load()
|
||||
if c.hasRun.Swap(true) {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", c))
|
||||
}
|
||||
|
||||
c.runOnce.Do(func() {
|
||||
c.hasRun.Store(true)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
retentionLog := c.logger.With("task_name", "clean_by_retention_policy")
|
||||
exceededLog := c.logger.With("task_name", "clean_exceeded_storage_backups")
|
||||
staleLog := c.logger.With("task_name", "clean_stale_basebackups")
|
||||
|
||||
ticker := time.NewTicker(cleanerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
if err := c.cleanByRetentionPolicy(retentionLog); err != nil {
|
||||
retentionLog.Error("failed to clean backups by retention policy", "error", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(cleanerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
if err := c.cleanExceededStorageBackups(exceededLog); err != nil {
|
||||
exceededLog.Error("failed to clean exceeded backups", "error", err)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := c.cleanByRetentionPolicy(); err != nil {
|
||||
c.logger.Error("Failed to clean backups by retention policy", "error", err)
|
||||
}
|
||||
|
||||
if err := c.cleanExceededBackups(); err != nil {
|
||||
c.logger.Error("Failed to clean exceeded backups", "error", err)
|
||||
}
|
||||
|
||||
if err := c.cleanStaleUploadedBasebackups(); err != nil {
|
||||
c.logger.Error("Failed to clean stale uploaded basebackups", "error", err)
|
||||
}
|
||||
if err := c.cleanStaleUploadedBasebackups(staleLog); err != nil {
|
||||
staleLog.Error("failed to clean stale uploaded basebackups", "error", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", c))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,7 +102,7 @@ func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemo
|
||||
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
|
||||
func (c *BackupCleaner) cleanStaleUploadedBasebackups(logger *slog.Logger) error {
|
||||
staleBackups, err := c.backupRepository.FindStaleUploadedBasebackups(
|
||||
time.Now().UTC().Add(-10 * time.Minute),
|
||||
)
|
||||
@@ -113,31 +111,30 @@ func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
|
||||
}
|
||||
|
||||
for _, backup := range staleBackups {
|
||||
backupLog := logger.With("database_id", backup.DatabaseID, "backup_id", backup.ID)
|
||||
|
||||
staleStorage, storageErr := c.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr != nil {
|
||||
c.logger.Error(
|
||||
"Failed to get storage for stale basebackup cleanup",
|
||||
"backupId", backup.ID,
|
||||
"storageId", backup.StorageID,
|
||||
backupLog.Error(
|
||||
"failed to get storage for stale basebackup cleanup",
|
||||
"storage_id", backup.StorageID,
|
||||
"error", storageErr,
|
||||
)
|
||||
} else {
|
||||
if err := staleStorage.DeleteFile(c.fieldEncryptor, backup.FileName); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete stale basebackup file",
|
||||
"backupId", backup.ID,
|
||||
"fileName", backup.FileName,
|
||||
"error", err,
|
||||
backupLog.Error(
|
||||
fmt.Sprintf("failed to delete stale basebackup file: %s", backup.FileName),
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
metadataFileName := backup.FileName + ".metadata"
|
||||
if err := staleStorage.DeleteFile(c.fieldEncryptor, metadataFileName); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete stale basebackup metadata file",
|
||||
"backupId", backup.ID,
|
||||
"fileName", metadataFileName,
|
||||
"error", err,
|
||||
backupLog.Error(
|
||||
fmt.Sprintf("failed to delete stale basebackup metadata file: %s", metadataFileName),
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -147,77 +144,67 @@ func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
|
||||
backup.FailMessage = &failMsg
|
||||
|
||||
if err := c.backupRepository.Save(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to mark stale uploaded basebackup as failed",
|
||||
"backupId", backup.ID,
|
||||
"error", err,
|
||||
)
|
||||
backupLog.Error("failed to mark stale uploaded basebackup as failed", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Marked stale uploaded basebackup as failed and cleaned storage",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", backup.DatabaseID,
|
||||
)
|
||||
backupLog.Info("marked stale uploaded basebackup as failed and cleaned storage")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByRetentionPolicy() error {
|
||||
func (c *BackupCleaner) cleanByRetentionPolicy(logger *slog.Logger) error {
|
||||
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
dbLog := logger.With("database_id", backupConfig.DatabaseID, "policy", backupConfig.RetentionPolicyType)
|
||||
|
||||
var cleanErr error
|
||||
|
||||
switch backupConfig.RetentionPolicyType {
|
||||
case backups_config.RetentionPolicyTypeCount:
|
||||
cleanErr = c.cleanByCount(backupConfig)
|
||||
cleanErr = c.cleanByCount(dbLog, backupConfig)
|
||||
case backups_config.RetentionPolicyTypeGFS:
|
||||
cleanErr = c.cleanByGFS(backupConfig)
|
||||
cleanErr = c.cleanByGFS(dbLog, backupConfig)
|
||||
default:
|
||||
cleanErr = c.cleanByTimePeriod(backupConfig)
|
||||
cleanErr = c.cleanByTimePeriod(dbLog, backupConfig)
|
||||
}
|
||||
|
||||
if cleanErr != nil {
|
||||
c.logger.Error(
|
||||
"Failed to clean backups by retention policy",
|
||||
"databaseId", backupConfig.DatabaseID,
|
||||
"policy", backupConfig.RetentionPolicyType,
|
||||
"error", cleanErr,
|
||||
)
|
||||
dbLog.Error("failed to clean backups by retention policy", "error", cleanErr)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanExceededBackups() error {
|
||||
func (c *BackupCleaner) cleanExceededStorageBackups(logger *slog.Logger) error {
|
||||
if !config.GetEnv().IsCloud {
|
||||
return nil
|
||||
}
|
||||
|
||||
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
if backupConfig.MaxBackupsTotalSizeMB <= 0 {
|
||||
dbLog := logger.With("database_id", backupConfig.DatabaseID)
|
||||
|
||||
subscription, subErr := c.billingService.GetSubscription(dbLog, backupConfig.DatabaseID)
|
||||
if subErr != nil {
|
||||
dbLog.Error("failed to get subscription for exceeded backups check", "error", subErr)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.cleanExceededBackupsForDatabase(
|
||||
backupConfig.DatabaseID,
|
||||
backupConfig.MaxBackupsTotalSizeMB,
|
||||
); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to clean exceeded backups for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
storageLimitMB := int64(subscription.GetBackupsStorageGB()) * 1024
|
||||
|
||||
if err := c.cleanExceededBackupsForDatabase(dbLog, backupConfig.DatabaseID, storageLimitMB); err != nil {
|
||||
dbLog.Error("failed to clean exceeded backups for database", "error", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -225,7 +212,7 @@ func (c *BackupCleaner) cleanExceededBackups() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByTimePeriod(backupConfig *backups_config.BackupConfig) error {
|
||||
func (c *BackupCleaner) cleanByTimePeriod(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
|
||||
if backupConfig.RetentionTimePeriod == "" {
|
||||
return nil
|
||||
}
|
||||
@@ -255,21 +242,17 @@ func (c *BackupCleaner) cleanByTimePeriod(backupConfig *backups_config.BackupCon
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
|
||||
logger.Error("failed to delete old backup", "backup_id", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted old backup",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", backupConfig.DatabaseID,
|
||||
)
|
||||
logger.Info("deleted old backup", "backup_id", backup.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByCount(backupConfig *backups_config.BackupConfig) error {
|
||||
func (c *BackupCleaner) cleanByCount(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
|
||||
if backupConfig.RetentionCount <= 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -298,28 +281,20 @@ func (c *BackupCleaner) cleanByCount(backupConfig *backups_config.BackupConfig)
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete backup by count policy",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
logger.Error("failed to delete backup by count policy", "backup_id", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted backup by count policy",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", backupConfig.DatabaseID,
|
||||
"retentionCount", backupConfig.RetentionCount,
|
||||
logger.Info(
|
||||
fmt.Sprintf("deleted backup by count policy: retention count is %d", backupConfig.RetentionCount),
|
||||
"backup_id", backup.ID,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByGFS(backupConfig *backups_config.BackupConfig) error {
|
||||
func (c *BackupCleaner) cleanByGFS(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
|
||||
if backupConfig.RetentionGfsHours <= 0 && backupConfig.RetentionGfsDays <= 0 &&
|
||||
backupConfig.RetentionGfsWeeks <= 0 && backupConfig.RetentionGfsMonths <= 0 &&
|
||||
backupConfig.RetentionGfsYears <= 0 {
|
||||
@@ -357,29 +332,20 @@ func (c *BackupCleaner) cleanByGFS(backupConfig *backups_config.BackupConfig) er
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete backup by GFS policy",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
logger.Error("failed to delete backup by GFS policy", "backup_id", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted backup by GFS policy",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", backupConfig.DatabaseID,
|
||||
)
|
||||
logger.Info("deleted backup by GFS policy", "backup_id", backup.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanExceededBackupsForDatabase(
|
||||
logger *slog.Logger,
|
||||
databaseID uuid.UUID,
|
||||
limitperDbMB int64,
|
||||
limitPerDbMB int64,
|
||||
) error {
|
||||
for {
|
||||
backupsTotalSizeMB, err := c.backupRepository.GetTotalSizeByDatabase(databaseID)
|
||||
@@ -387,7 +353,7 @@ func (c *BackupCleaner) cleanExceededBackupsForDatabase(
|
||||
return err
|
||||
}
|
||||
|
||||
if backupsTotalSizeMB <= float64(limitperDbMB) {
|
||||
if backupsTotalSizeMB <= float64(limitPerDbMB) {
|
||||
break
|
||||
}
|
||||
|
||||
@@ -400,59 +366,27 @@ func (c *BackupCleaner) cleanExceededBackupsForDatabase(
|
||||
}
|
||||
|
||||
if len(oldestBackups) == 0 {
|
||||
c.logger.Warn(
|
||||
"No backups to delete but still over limit",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"totalSizeMB",
|
||||
backupsTotalSizeMB,
|
||||
"limitMB",
|
||||
limitperDbMB,
|
||||
)
|
||||
logger.Warn(fmt.Sprintf(
|
||||
"no backups to delete but still over limit: total size is %.1f MB, limit is %d MB",
|
||||
backupsTotalSizeMB, limitPerDbMB,
|
||||
))
|
||||
break
|
||||
}
|
||||
|
||||
backup := oldestBackups[0]
|
||||
if isRecentBackup(backup) {
|
||||
c.logger.Warn(
|
||||
"Oldest backup is too recent to delete, stopping size cleanup",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"totalSizeMB",
|
||||
backupsTotalSizeMB,
|
||||
"limitMB",
|
||||
limitperDbMB,
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete exceeded backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
logger.Error("failed to delete exceeded backup", "backup_id", backup.ID, "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted exceeded backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"backupSizeMB",
|
||||
backup.BackupSizeMb,
|
||||
"totalSizeMB",
|
||||
backupsTotalSizeMB,
|
||||
"limitMB",
|
||||
limitperDbMB,
|
||||
logger.Info(
|
||||
fmt.Sprintf("deleted exceeded backup: backup size is %.1f MB, total size is %.1f MB, limit is %d MB",
|
||||
backup.BackupSizeMb, backupsTotalSizeMB, limitPerDbMB),
|
||||
"backup_id", backup.ID,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
|
||||
// backupsEveryDay returns n backups, newest-first, each 1 day apart.
|
||||
backupsEveryDay := func(n int) []*backups_core.Backup {
|
||||
bs := make([]*backups_core.Backup, n)
|
||||
for i := 0; i < n; i++ {
|
||||
for i := range n {
|
||||
bs[i] = newBackup(ref.Add(-time.Duration(i) * day))
|
||||
}
|
||||
return bs
|
||||
@@ -42,7 +42,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
|
||||
// backupsEveryWeek returns n backups, newest-first, each 7 days apart.
|
||||
backupsEveryWeek := func(n int) []*backups_core.Backup {
|
||||
bs := make([]*backups_core.Backup, n)
|
||||
for i := 0; i < n; i++ {
|
||||
for i := range n {
|
||||
bs[i] = newBackup(ref.Add(-time.Duration(i) * week))
|
||||
}
|
||||
return bs
|
||||
@@ -53,7 +53,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
|
||||
// backupsEveryHour returns n backups, newest-first, each 1 hour apart.
|
||||
backupsEveryHour := func(n int) []*backups_core.Backup {
|
||||
bs := make([]*backups_core.Backup, n)
|
||||
for i := 0; i < n; i++ {
|
||||
for i := range n {
|
||||
bs[i] = newBackup(ref.Add(-time.Duration(i) * hour))
|
||||
}
|
||||
return bs
|
||||
@@ -62,7 +62,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
|
||||
// backupsEveryMonth returns n backups, newest-first, each ~1 month apart.
|
||||
backupsEveryMonth := func(n int) []*backups_core.Backup {
|
||||
bs := make([]*backups_core.Backup, n)
|
||||
for i := 0; i < n; i++ {
|
||||
for i := range n {
|
||||
bs[i] = newBackup(ref.AddDate(0, -i, 0))
|
||||
}
|
||||
return bs
|
||||
@@ -71,7 +71,7 @@ func Test_BuildGFSKeepSet(t *testing.T) {
|
||||
// backupsEveryYear returns n backups, newest-first, each 1 year apart.
|
||||
backupsEveryYear := func(n int) []*backups_core.Backup {
|
||||
bs := make([]*backups_core.Backup, n)
|
||||
for i := 0; i < n; i++ {
|
||||
for i := range n {
|
||||
bs[i] = newBackup(ref.AddDate(-i, 0, 0))
|
||||
}
|
||||
return bs
|
||||
@@ -410,7 +410,7 @@ func Test_CleanByGFS_KeepsCorrectBackupsPerSlot(t *testing.T) {
|
||||
|
||||
// Create 5 backups on 5 different days; only the 3 newest days should be kept
|
||||
var backupIDs []uuid.UUID
|
||||
for i := 0; i < 5; i++ {
|
||||
for i := range 5 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -425,7 +425,7 @@ func Test_CleanByGFS_KeepsCorrectBackupsPerSlot(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -486,7 +486,7 @@ func Test_CleanByGFS_WithWeeklyAndMonthlySlots_KeepsWiderSpread(t *testing.T) {
|
||||
// Create one backup per week for 6 weeks (each on Monday of that week)
|
||||
// GFS should keep: 2 daily (most recent 2 unique days) + 2 weekly + 1 monthly = up to 5 unique
|
||||
var createdIDs []uuid.UUID
|
||||
for i := 0; i < 6; i++ {
|
||||
for i := range 6 {
|
||||
weekOffset := time.Duration(5-i) * 7 * 24 * time.Hour
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
@@ -502,7 +502,7 @@ func Test_CleanByGFS_WithWeeklyAndMonthlySlots_KeepsWiderSpread(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -561,7 +561,7 @@ func Test_CleanByGFS_WithHourlySlots_KeepsCorrectBackups(t *testing.T) {
|
||||
|
||||
// Create 5 backups spaced 1 hour apart; only the 3 newest hours should be kept
|
||||
var backupIDs []uuid.UUID
|
||||
for i := 0; i < 5; i++ {
|
||||
for i := range 5 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -576,7 +576,7 @@ func Test_CleanByGFS_WithHourlySlots_KeepsCorrectBackups(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -677,7 +677,7 @@ func Test_CleanByGFS_SkipsRecentBackup_WhenNotInKeepSet(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -759,7 +759,7 @@ func Test_CleanByGFS_With20DailyBackups_KeepsOnlyExpectedCount(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -824,8 +824,8 @@ func Test_CleanByGFS_WithMultipleBackupsPerDay_KeepsOnlyOnePerDailySlot(t *testi
|
||||
|
||||
// Create 3 backups per day for 10 days = 30 total, all beyond grace period.
|
||||
// Each day gets backups at base+0h, base+6h, base+12h.
|
||||
for day := 0; day < 10; day++ {
|
||||
for sub := 0; sub < 3; sub++ {
|
||||
for day := range 10 {
|
||||
for sub := range 3 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -844,7 +844,7 @@ func Test_CleanByGFS_WithMultipleBackupsPerDay_KeepsOnlyOnePerDailySlot(t *testi
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -915,7 +915,7 @@ func Test_CleanByGFS_With24HourlySlotsAnd23DailyBackups_DeletesExcessBackups(t *
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
for i := 0; i < 23; i++ {
|
||||
for i := range 23 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -929,7 +929,7 @@ func Test_CleanByGFS_With24HourlySlotsAnd23DailyBackups_DeletesExcessBackups(t *
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -985,7 +985,7 @@ func Test_CleanByGFS_WithDisabledHourlySlotsAnd23DailyBackups_DeletesExcessBacku
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
for i := 0; i < 23; i++ {
|
||||
for i := range 23 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -999,7 +999,7 @@ func Test_CleanByGFS_WithDisabledHourlySlotsAnd23DailyBackups_DeletesExcessBacku
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -1055,7 +1055,7 @@ func Test_CleanByGFS_WithDailySlotsAndWeeklyBackups_DeletesExcessBackups(t *test
|
||||
// Create 10 weekly backups (1 per week, all >2h old past grace period).
|
||||
// With 7d/4w config, correct behavior: ~8 kept (4 weekly + overlap with daily for recent ones).
|
||||
// Daily slots should NOT absorb weekly backups that are older than 7 days.
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -1069,7 +1069,7 @@ func Test_CleanByGFS_WithDailySlotsAndWeeklyBackups_DeletesExcessBackups(t *test
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -1138,7 +1138,7 @@ func Test_CleanByGFS_WithWeeklySlotsAndMonthlyBackups_DeletesExcessBackups(t *te
|
||||
// With 52w/3m config, correct behavior: 3 kept (3 monthly slots; weekly should only
|
||||
// cover recent 52 weeks but not artificially retain old monthly backups).
|
||||
// Bug: all 8 kept because each monthly backup fills a unique weekly slot.
|
||||
for i := 0; i < 8; i++ {
|
||||
for i := range 8 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -1152,7 +1152,7 @@ func Test_CleanByGFS_WithWeeklySlotsAndMonthlyBackups_DeletesExcessBackups(t *te
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
@@ -17,6 +20,7 @@ import (
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
|
||||
@@ -51,6 +55,7 @@ func Test_CleanOldBackups_DeletesBackupsOlderThanRetentionTimePeriod(t *testing.
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -89,7 +94,7 @@ func Test_CleanOldBackups_DeletesBackupsOlderThanRetentionTimePeriod(t *testing.
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -129,6 +134,7 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -145,7 +151,7 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -154,7 +160,8 @@ func Test_CleanOldBackups_SkipsDatabaseWithForeverRetentionPeriod(t *testing.T)
|
||||
assert.Equal(t, oldBackup.ID, remainingBackups[0].ID)
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
|
||||
func Test_CleanExceededBackups_WhenUnderStorageLimit_NoBackupsDeleted(t *testing.T) {
|
||||
enableCloud(t)
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
@@ -178,33 +185,36 @@ func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 100,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
for i := range 3 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 16.67,
|
||||
BackupSizeMb: 100,
|
||||
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
|
||||
}
|
||||
cleaner := CreateTestBackupCleaner(mockBilling)
|
||||
err = cleaner.cleanExceededStorageBackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -212,7 +222,8 @@ func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
|
||||
assert.Equal(t, 3, len(remainingBackups))
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T) {
|
||||
func Test_CleanExceededBackups_WhenOverStorageLimit_DeletesOldestBackups(t *testing.T) {
|
||||
enableCloud(t)
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
@@ -236,27 +247,29 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 30,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 5 backups at 300 MB each = 1500 MB total, limit = 1 GB (1024 MB)
|
||||
// Expect 2 oldest deleted, 3 remain (900 MB < 1024 MB)
|
||||
now := time.Now().UTC()
|
||||
var backupIDs []uuid.UUID
|
||||
for i := 0; i < 5; i++ {
|
||||
for i := range 5 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10,
|
||||
BackupSizeMb: 300,
|
||||
CreatedAt: now.Add(-time.Duration(4-i) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
@@ -264,8 +277,11 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
|
||||
backupIDs = append(backupIDs, backup.ID)
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
|
||||
}
|
||||
cleaner := CreateTestBackupCleaner(mockBilling)
|
||||
err = cleaner.cleanExceededStorageBackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -284,6 +300,7 @@ func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T)
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
|
||||
enableCloud(t)
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
@@ -307,28 +324,29 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 50,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
// 3 completed at 500 MB each = 1500 MB, limit = 1 GB (1024 MB)
|
||||
completedBackups := make([]*backups_core.Backup, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
for i := range 3 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 30,
|
||||
BackupSizeMb: 500,
|
||||
CreatedAt: now.Add(-time.Duration(3-i) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
@@ -347,8 +365,11 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
|
||||
err = backupRepository.Save(inProgressBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
|
||||
}
|
||||
cleaner := CreateTestBackupCleaner(mockBilling)
|
||||
err = cleaner.cleanExceededStorageBackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -365,7 +386,8 @@ func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
|
||||
assert.True(t, inProgressFound, "In-progress backup should not be deleted")
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
|
||||
func Test_CleanExceededBackups_WithZeroStorageLimit_RemovesAllBackups(t *testing.T) {
|
||||
enableCloud(t)
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
@@ -389,38 +411,42 @@ func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 0,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 100,
|
||||
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
|
||||
CreatedAt: time.Now().UTC().Add(-time.Duration(i+2) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
// StorageGB=0 means no storage allowed — all backups should be removed
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{StorageGB: 0, Status: billing_models.StatusActive},
|
||||
}
|
||||
cleaner := CreateTestBackupCleaner(mockBilling)
|
||||
err = cleaner.cleanExceededStorageBackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 10, len(remainingBackups))
|
||||
assert.Equal(t, 0, len(remainingBackups))
|
||||
}
|
||||
|
||||
func Test_GetTotalSizeByDatabase_CalculatesCorrectly(t *testing.T) {
|
||||
@@ -522,13 +548,14 @@ func Test_CleanByCount_KeepsNewestNBackups_DeletesOlder(t *testing.T) {
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
now := time.Now().UTC()
|
||||
var backupIDs []uuid.UUID
|
||||
for i := 0; i < 5; i++ {
|
||||
for i := range 5 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -545,7 +572,7 @@ func Test_CleanByCount_KeepsNewestNBackups_DeletesOlder(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -594,11 +621,12 @@ func Test_CleanByCount_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
for i := range 5 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -612,7 +640,7 @@ func Test_CleanByCount_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -651,13 +679,14 @@ func Test_CleanByCount_DoesNotDeleteInProgressBackups(t *testing.T) {
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
for i := range 3 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -682,7 +711,7 @@ func Test_CleanByCount_DoesNotDeleteInProgressBackups(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -776,6 +805,7 @@ func Test_CleanByTimePeriod_SkipsRecentBackup_EvenIfOlderThanRetention(t *testin
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -805,7 +835,7 @@ func Test_CleanByTimePeriod_SkipsRecentBackup_EvenIfOlderThanRetention(t *testin
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -847,6 +877,7 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -893,7 +924,7 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
|
||||
}
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanByRetentionPolicy()
|
||||
err = cleaner.cleanByRetentionPolicy(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -914,7 +945,8 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
|
||||
assert.True(t, remainingIDs[newestBackup.ID], "Newest backup should be preserved")
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testing.T) {
|
||||
func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverStorageLimit(t *testing.T) {
|
||||
enableCloud(t)
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
@@ -937,18 +969,18 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
|
||||
|
||||
interval := createTestInterval()
|
||||
|
||||
// Total size limit is 10 MB. We have two backups of 8 MB each (16 MB total).
|
||||
// Total size limit = 1 GB (1024 MB). Two backups of 600 MB each (1200 MB total).
|
||||
// The oldest backup was created 30 minutes ago — within the grace period.
|
||||
// The cleaner must stop and leave both backups intact.
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 10,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
@@ -960,7 +992,7 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 8,
|
||||
BackupSizeMb: 600,
|
||||
CreatedAt: now.Add(-30 * time.Minute),
|
||||
}
|
||||
newerRecentBackup := &backups_core.Backup{
|
||||
@@ -968,7 +1000,7 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 8,
|
||||
BackupSizeMb: 600,
|
||||
CreatedAt: now.Add(-10 * time.Minute),
|
||||
}
|
||||
|
||||
@@ -977,8 +1009,11 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
|
||||
err = backupRepository.Save(newerRecentBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
|
||||
}
|
||||
cleaner := CreateTestBackupCleaner(mockBilling)
|
||||
err = cleaner.cleanExceededStorageBackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -991,6 +1026,82 @@ func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testi
|
||||
)
|
||||
}
|
||||
|
||||
func Test_CleanExceededStorageBackups_WhenNonCloud_SkipsCleanup(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: backups_config.RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
Encryption: backups_config.BackupEncryptionEncrypted,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 5 backups at 500 MB each = 2500 MB, would exceed 1 GB limit in cloud mode
|
||||
now := time.Now().UTC()
|
||||
for i := range 5 {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 500,
|
||||
CreatedAt: now.Add(-time.Duration(i+2) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// IsCloud is false by default — cleaner should skip entirely
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{StorageGB: 1, Status: billing_models.StatusActive},
|
||||
}
|
||||
cleaner := CreateTestBackupCleaner(mockBilling)
|
||||
err = cleaner.cleanExceededStorageBackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 5, len(remainingBackups), "All backups must remain in non-cloud mode")
|
||||
}
|
||||
|
||||
type mockBillingService struct {
|
||||
subscription *billing_models.Subscription
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockBillingService) GetSubscription(
|
||||
logger *slog.Logger,
|
||||
databaseID uuid.UUID,
|
||||
) (*billing_models.Subscription, error) {
|
||||
return m.subscription, m.err
|
||||
}
|
||||
|
||||
// Mock listener for testing
|
||||
type mockBackupRemoveListener struct {
|
||||
onBeforeBackupRemove func(*backups_core.Backup) error
|
||||
@@ -1041,7 +1152,7 @@ func Test_CleanStaleUploadedBasebackups_MarksAsFailed(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(staleBackup.ID)
|
||||
@@ -1088,7 +1199,7 @@ func Test_CleanStaleUploadedBasebackups_SkipsRecentUploads(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(recentBackup.ID)
|
||||
@@ -1131,7 +1242,7 @@ func Test_CleanStaleUploadedBasebackups_SkipsActiveStreaming(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(activeBackup.ID)
|
||||
@@ -1179,7 +1290,7 @@ func Test_CleanStaleUploadedBasebackups_CleansStorageFiles(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
err = cleaner.cleanStaleUploadedBasebackups(testLogger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(staleBackup.ID)
|
||||
@@ -1189,6 +1300,18 @@ func Test_CleanStaleUploadedBasebackups_CleansStorageFiles(t *testing.T) {
|
||||
assert.Contains(t, *updated.FailMessage, "finalization timed out")
|
||||
}
|
||||
|
||||
func enableCloud(t *testing.T) {
|
||||
t.Helper()
|
||||
config.GetEnv().IsCloud = true
|
||||
t.Cleanup(func() {
|
||||
config.GetEnv().IsCloud = false
|
||||
})
|
||||
}
|
||||
|
||||
func testLogger() *slog.Logger {
|
||||
return logger.GetLogger().With("task_name", "test")
|
||||
}
|
||||
|
||||
func createTestInterval() *intervals.Interval {
|
||||
timeOfDay := "04:00"
|
||||
interval := &intervals.Interval{
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -10,6 +9,7 @@ import (
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/billing"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
@@ -28,10 +28,10 @@ var backupCleaner = &BackupCleaner{
|
||||
backupRepository,
|
||||
storages.GetStorageService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
billing.GetBillingService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
logger.GetLogger(),
|
||||
[]backups_core.BackupRemoveListener{},
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
@@ -41,7 +41,6 @@ var backupNodesRegistry = &BackupNodesRegistry{
|
||||
cache_utils.DefaultCacheTimeout,
|
||||
cache_utils.NewPubSubManager(),
|
||||
cache_utils.NewPubSubManager(),
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
@@ -63,7 +62,6 @@ var backuperNode = &BackuperNode{
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
getNodeID(),
|
||||
time.Time{},
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
@@ -73,11 +71,11 @@ var backupsScheduler = &BackupsScheduler{
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
databases.GetDatabaseService(),
|
||||
billing.GetBillingService(),
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode,
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
)
|
||||
|
||||
type BillingService interface {
|
||||
GetSubscription(logger *slog.Logger, databaseID uuid.UUID) (*billing_models.Subscription, error)
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -50,36 +49,30 @@ type BackupNodesRegistry struct {
|
||||
pubsubBackups *cache_utils.PubSubManager
|
||||
pubsubCompletions *cache_utils.PubSubManager
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) Run(ctx context.Context) {
|
||||
wasAlreadyRun := r.hasRun.Load()
|
||||
if r.hasRun.Swap(true) {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", r))
|
||||
}
|
||||
|
||||
r.runOnce.Do(func() {
|
||||
r.hasRun.Store(true)
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
|
||||
}
|
||||
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
|
||||
}
|
||||
ticker := time.NewTicker(cleanupTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
ticker := time.NewTicker(cleanupTickerInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes", "error", err)
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes", "error", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", r))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -322,7 +320,7 @@ func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) {
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer cancel()
|
||||
|
||||
invalidKey := nodeInfoKeyPrefix + uuid.New().String() + nodeInfoKeySuffix
|
||||
@@ -331,7 +329,7 @@ func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) {
|
||||
registry.client.B().Set().Key(invalidKey).Value("invalid json data").Build(),
|
||||
)
|
||||
defer func() {
|
||||
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
cleanupCtx, cleanupCancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer cleanupCancel()
|
||||
registry.client.Do(cleanupCtx, registry.client.B().Del().Key(invalidKey).Build())
|
||||
}()
|
||||
@@ -401,7 +399,7 @@ func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) {
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
|
||||
@@ -419,7 +417,7 @@ func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) {
|
||||
modifiedData, err := json.Marshal(node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
setCtx, setCancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer setCancel()
|
||||
setResult := registry.client.Do(
|
||||
setCtx,
|
||||
@@ -464,7 +462,7 @@ func Test_GetBackupNodesStats_ExcludesStaleNodesFromCache(t *testing.T) {
|
||||
err = registry.IncrementBackupsInProgress(node3.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
|
||||
@@ -482,7 +480,7 @@ func Test_GetBackupNodesStats_ExcludesStaleNodesFromCache(t *testing.T) {
|
||||
modifiedData, err := json.Marshal(node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
setCtx, setCancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer setCancel()
|
||||
setResult := registry.client.Do(
|
||||
setCtx,
|
||||
@@ -524,7 +522,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
|
||||
err = registry.IncrementBackupsInProgress(node2.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
|
||||
@@ -542,7 +540,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
|
||||
modifiedData, err := json.Marshal(node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
setCtx, setCancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer setCancel()
|
||||
setResult := registry.client.Do(
|
||||
setCtx,
|
||||
@@ -553,7 +551,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
|
||||
err = registry.cleanupDeadNodes()
|
||||
assert.NoError(t, err)
|
||||
|
||||
checkCtx, checkCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
checkCtx, checkCancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer checkCancel()
|
||||
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
|
||||
@@ -566,7 +564,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
|
||||
node2.ID.String(),
|
||||
nodeActiveBackupsSuffix,
|
||||
)
|
||||
counterCtx, counterCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
counterCtx, counterCancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer counterCancel()
|
||||
counterResult := registry.client.Do(
|
||||
counterCtx,
|
||||
@@ -575,7 +573,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
|
||||
assert.Error(t, counterResult.Error())
|
||||
|
||||
activeInfoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node1.ID.String(), nodeInfoKeySuffix)
|
||||
activeCtx, activeCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
activeCtx, activeCancel := context.WithTimeout(t.Context(), registry.timeout)
|
||||
defer activeCancel()
|
||||
activeResult := registry.client.Do(
|
||||
activeCtx,
|
||||
@@ -601,8 +599,6 @@ func createTestRegistry() *BackupNodesRegistry {
|
||||
timeout: cache_utils.DefaultCacheTimeout,
|
||||
pubsubBackups: cache_utils.NewPubSubManager(),
|
||||
pubsubCompletions: cache_utils.NewPubSubManager(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -732,7 +728,7 @@ func Test_SubscribeNodeForBackupsAssignment_HandlesInvalidJson(t *testing.T) {
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
err = registry.pubsubBackups.Publish(ctx, "backup:submit", "invalid json")
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -978,7 +974,7 @@ func Test_SubscribeForBackupsCompletions_HandlesInvalidJson(t *testing.T) {
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
err = registry.pubsubCompletions.Publish(ctx, "backup:completion", "invalid json")
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -1093,7 +1089,7 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
|
||||
receivedAll2 := []uuid.UUID{}
|
||||
receivedAll3 := []uuid.UUID{}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
for range 3 {
|
||||
select {
|
||||
case received := <-receivedBackups1:
|
||||
receivedAll1 = append(receivedAll1, received)
|
||||
@@ -1102,7 +1098,7 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
for range 3 {
|
||||
select {
|
||||
case received := <-receivedBackups2:
|
||||
receivedAll2 = append(receivedAll2, received)
|
||||
@@ -1111,7 +1107,7 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
for range 3 {
|
||||
select {
|
||||
case received := <-receivedBackups3:
|
||||
receivedAll3 = append(receivedAll3, received)
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -29,6 +28,7 @@ type BackupsScheduler struct {
|
||||
taskCancelManager *task_cancellation.TaskCancelManager
|
||||
backupNodesRegistry *BackupNodesRegistry
|
||||
databaseService *databases.DatabaseService
|
||||
billingService BillingService
|
||||
|
||||
lastBackupTime time.Time
|
||||
logger *slog.Logger
|
||||
@@ -36,68 +36,61 @@ type BackupsScheduler struct {
|
||||
backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
|
||||
backuperNode *BackuperNode
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) Run(ctx context.Context) {
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
|
||||
if config.GetEnv().IsManyNodesMode {
|
||||
// wait other nodes to start
|
||||
time.Sleep(schedulerStartupDelay)
|
||||
}
|
||||
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to subscribe to backup completions", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
|
||||
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(schedulerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.checkDeadNodesAndFailBackups(); err != nil {
|
||||
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
if s.hasRun.Swap(true) {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
|
||||
if config.GetEnv().IsManyNodesMode {
|
||||
// wait other nodes to start
|
||||
time.Sleep(schedulerStartupDelay)
|
||||
}
|
||||
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to subscribe to backup completions", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
|
||||
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(schedulerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.checkDeadNodesAndFailBackups(); err != nil {
|
||||
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) IsSchedulerRunning() bool {
|
||||
@@ -127,6 +120,34 @@ func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotif
|
||||
return
|
||||
}
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
subscription, subErr := s.billingService.GetSubscription(s.logger, database.ID)
|
||||
if subErr != nil || !subscription.CanCreateNewBackups() {
|
||||
failMessage := "subscription has expired, please renew"
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: *backupConfig.StorageID,
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
IsSkipRetry: true,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
backup.GenerateFilename(database.Name)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error(
|
||||
"failed to save failed backup for expired subscription",
|
||||
"database_id", database.ID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check for existing in-progress backups
|
||||
inProgressBackups, err := s.backupRepository.FindByDatabaseIdAndStatus(
|
||||
database.ID,
|
||||
@@ -342,6 +363,31 @@ func (s *BackupsScheduler) runPendingBackups() error {
|
||||
continue
|
||||
}
|
||||
|
||||
if database.IsAgentManagedBackup() {
|
||||
continue
|
||||
}
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
subscription, subErr := s.billingService.GetSubscription(s.logger, backupConfig.DatabaseID)
|
||||
if subErr != nil {
|
||||
s.logger.Warn(
|
||||
"failed to get subscription, skipping backup",
|
||||
"database_id", backupConfig.DatabaseID,
|
||||
"error", subErr,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
if !subscription.CanCreateNewBackups() {
|
||||
s.logger.Debug(
|
||||
"subscription is not active, skipping scheduled backup",
|
||||
"database_id", backupConfig.DatabaseID,
|
||||
"subscription_status", subscription.Status,
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
s.StartBackup(database, remainedBackupTryCount == 1)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
@@ -20,6 +22,128 @@ import (
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
|
||||
func Test_RunPendingBackups_ByDatabaseType_OnlySchedulesNonAgentManagedBackups(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
createDatabase func(workspaceID uuid.UUID, storage *storages.Storage, notifier *notifiers.Notifier) *databases.Database
|
||||
isBackupExpected bool
|
||||
needsBackuperNode bool
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "PostgreSQL PG_DUMP - backup runs",
|
||||
createDatabase: func(workspaceID uuid.UUID, storage *storages.Storage, notifier *notifiers.Notifier) *databases.Database {
|
||||
return databases.CreateTestDatabase(workspaceID, storage, notifier)
|
||||
},
|
||||
isBackupExpected: true,
|
||||
needsBackuperNode: true,
|
||||
},
|
||||
{
|
||||
name: "PostgreSQL WAL_V1 - backup skipped (agent-managed)",
|
||||
createDatabase: func(workspaceID uuid.UUID, _ *storages.Storage, notifier *notifiers.Notifier) *databases.Database {
|
||||
return databases.CreateTestPostgresWalDatabase(workspaceID, notifier)
|
||||
},
|
||||
isBackupExpected: false,
|
||||
needsBackuperNode: false,
|
||||
},
|
||||
{
|
||||
name: "MariaDB - backup runs",
|
||||
createDatabase: func(workspaceID uuid.UUID, _ *storages.Storage, notifier *notifiers.Notifier) *databases.Database {
|
||||
return databases.CreateTestMariadbDatabase(workspaceID, notifier)
|
||||
},
|
||||
isBackupExpected: true,
|
||||
needsBackuperNode: true,
|
||||
},
|
||||
{
|
||||
name: "MongoDB - backup runs",
|
||||
createDatabase: func(workspaceID uuid.UUID, _ *storages.Storage, notifier *notifiers.Notifier) *databases.Database {
|
||||
return databases.CreateTestMongodbDatabase(workspaceID, notifier)
|
||||
},
|
||||
isBackupExpected: true,
|
||||
needsBackuperNode: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
var backuperNode *BackuperNode
|
||||
var cancel context.CancelFunc
|
||||
|
||||
if tc.needsBackuperNode {
|
||||
backuperNode = CreateTestBackuperNode()
|
||||
cancel = StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
}
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := tc.createDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
|
||||
backupConfig.RetentionTimePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add old backup (24h ago)
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupsScheduler().runPendingBackups()
|
||||
|
||||
if tc.isBackupExpected {
|
||||
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 2)
|
||||
} else {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_RunPendingBackups_WhenLastBackupWasYesterday_CreatesNewBackup(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
@@ -845,7 +969,7 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
// Start scheduler so it can handle task completions
|
||||
scheduler := CreateTestScheduler()
|
||||
scheduler := CreateTestScheduler(nil)
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
@@ -942,7 +1066,7 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
// Start scheduler so it can handle task completions
|
||||
scheduler := CreateTestScheduler()
|
||||
scheduler := CreateTestScheduler(nil)
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
@@ -1209,7 +1333,7 @@ func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCa
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
// Create scheduler
|
||||
scheduler := CreateTestScheduler()
|
||||
scheduler := CreateTestScheduler(nil)
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
@@ -1335,3 +1459,313 @@ func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCa
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartBackup_WhenCloudAndSubscriptionExpired_CreatesFailedBackup(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{
|
||||
Status: billing_models.StatusExpired,
|
||||
},
|
||||
}
|
||||
scheduler := CreateTestScheduler(mockBilling)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
enableCloud(t)
|
||||
|
||||
scheduler.StartBackup(database, false)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
|
||||
newestBackup := backups[0]
|
||||
assert.Equal(t, backups_core.BackupStatusFailed, newestBackup.Status)
|
||||
assert.NotNil(t, newestBackup.FailMessage)
|
||||
assert.Equal(t, "subscription has expired, please renew", *newestBackup.FailMessage)
|
||||
assert.True(t, newestBackup.IsSkipRetry)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartBackup_WhenCloudAndSubscriptionActive_ProceedsNormally(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{
|
||||
Status: billing_models.StatusActive,
|
||||
StorageGB: 10,
|
||||
},
|
||||
}
|
||||
scheduler := CreateTestScheduler(mockBilling)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
enableCloud(t)
|
||||
|
||||
scheduler.StartBackup(database, false)
|
||||
|
||||
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
|
||||
newestBackup := backups[0]
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, newestBackup.Status)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RunPendingBackups_WhenCloudAndSubscriptionExpired_SilentlySkips(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{
|
||||
Status: billing_models.StatusExpired,
|
||||
},
|
||||
}
|
||||
scheduler := CreateTestScheduler(mockBilling)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
|
||||
backupConfig.RetentionTimePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
enableCloud(t)
|
||||
|
||||
scheduler.runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1, "No new backup should be created, scheduler silently skips expired subscriptions")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartBackup_WhenNotCloudAndSubscriptionExpired_ProceedsNormally(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{
|
||||
Status: billing_models.StatusExpired,
|
||||
},
|
||||
}
|
||||
scheduler := CreateTestScheduler(mockBilling)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
scheduler.StartBackup(database, false)
|
||||
|
||||
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backups[0].Status,
|
||||
"Billing check should not apply in non-cloud mode")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RunPendingBackups_WhenNotCloudAndSubscriptionExpired_ProceedsNormally(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
mockBilling := &mockBillingService{
|
||||
subscription: &billing_models.Subscription{
|
||||
Status: billing_models.StatusExpired,
|
||||
},
|
||||
}
|
||||
scheduler := CreateTestScheduler(mockBilling)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
|
||||
backupConfig.RetentionTimePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
scheduler.runPendingBackups()
|
||||
|
||||
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 2, "Billing check should not apply in non-cloud mode, new backup should be created")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package backuping
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -35,58 +34,70 @@ func CreateTestRouter() *gin.Engine {
|
||||
return router
|
||||
}
|
||||
|
||||
func CreateTestBackupCleaner(billingService BillingService) *BackupCleaner {
|
||||
return &BackupCleaner{
|
||||
backupRepository,
|
||||
storages.GetStorageService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
billingService,
|
||||
encryption.GetFieldEncryptor(),
|
||||
logger.GetLogger(),
|
||||
[]backups_core.BackupRemoveListener{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestBackuperNode() *BackuperNode {
|
||||
return &BackuperNode{
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
createBackupUseCase: usecases.GetCreateBackupUsecase(),
|
||||
nodeID: uuid.New(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
logger.GetLogger(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
uuid.New(),
|
||||
time.Time{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestBackuperNodeWithUseCase(useCase backups_core.CreateBackupUsecase) *BackuperNode {
|
||||
return &BackuperNode{
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
createBackupUseCase: useCase,
|
||||
nodeID: uuid.New(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
logger.GetLogger(),
|
||||
useCase,
|
||||
uuid.New(),
|
||||
time.Time{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestScheduler() *BackupsScheduler {
|
||||
func CreateTestScheduler(billingService BillingService) *BackupsScheduler {
|
||||
return &BackupsScheduler{
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
taskCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
lastBackupTime: time.Now().UTC(),
|
||||
logger: logger.GetLogger(),
|
||||
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode: CreateTestBackuperNode(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
databases.GetDatabaseService(),
|
||||
billingService,
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
make(map[uuid.UUID]BackupToNodeRelation),
|
||||
CreateTestBackuperNode(),
|
||||
atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -40,12 +40,15 @@ func (c *BackupController) RegisterPublicRoutes(router *gin.RouterGroup) {
|
||||
|
||||
// GetBackups
|
||||
// @Summary Get backups for a database
|
||||
// @Description Get paginated backups for the specified database
|
||||
// @Description Get paginated backups for the specified database with optional filters
|
||||
// @Tags backups
|
||||
// @Produce json
|
||||
// @Param database_id query string true "Database ID"
|
||||
// @Param limit query int false "Number of items per page" default(10)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param status query []string false "Filter by backup status (can be repeated)" Enums(IN_PROGRESS, COMPLETED, FAILED, CANCELED)
|
||||
// @Param beforeDate query string false "Filter backups created before this date (RFC3339)" format(date-time)
|
||||
// @Param pgWalBackupType query string false "Filter by WAL backup type" Enums(PG_FULL_BACKUP, PG_WAL_SEGMENT)
|
||||
// @Success 200 {object} backups_dto.GetBackupsResponse
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
@@ -70,7 +73,9 @@ func (c *BackupController) GetBackups(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.backupService.GetBackups(user, databaseID, request.Limit, request.Offset)
|
||||
filters := c.buildBackupFilters(&request)
|
||||
|
||||
response, err := c.backupService.GetBackups(user, databaseID, request.Limit, request.Offset, filters)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -359,3 +364,35 @@ func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uu
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *BackupController) buildBackupFilters(
|
||||
request *backups_dto.GetBackupsRequest,
|
||||
) *backups_core.BackupFilters {
|
||||
isHasFilters := len(request.Statuses) > 0 ||
|
||||
request.BeforeDate != nil ||
|
||||
request.PgWalBackupType != nil
|
||||
|
||||
if !isHasFilters {
|
||||
return nil
|
||||
}
|
||||
|
||||
filters := &backups_core.BackupFilters{}
|
||||
|
||||
if len(request.Statuses) > 0 {
|
||||
statuses := make([]backups_core.BackupStatus, 0, len(request.Statuses))
|
||||
for _, statusStr := range request.Statuses {
|
||||
statuses = append(statuses, backups_core.BackupStatus(statusStr))
|
||||
}
|
||||
|
||||
filters.Statuses = statuses
|
||||
}
|
||||
|
||||
filters.BeforeDate = request.BeforeDate
|
||||
|
||||
if request.PgWalBackupType != nil {
|
||||
walType := backups_core.PgWalBackupType(*request.PgWalBackupType)
|
||||
filters.PgWalBackupType = &walType
|
||||
}
|
||||
|
||||
return filters
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
@@ -140,6 +141,225 @@ func Test_GetBackups_PermissionsEnforced(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GetBackups_WithStatusFilter_ReturnsFilteredBackups(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: now.Add(-3 * time.Hour),
|
||||
})
|
||||
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
CreatedAt: now.Add(-2 * time.Hour),
|
||||
})
|
||||
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusCanceled,
|
||||
CreatedAt: now.Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
// Single status filter
|
||||
var singleResponse backups_dto.GetBackupsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups?database_id=%s&status=COMPLETED", database.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&singleResponse,
|
||||
)
|
||||
|
||||
assert.Equal(t, int64(1), singleResponse.Total)
|
||||
assert.Len(t, singleResponse.Backups, 1)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, singleResponse.Backups[0].Status)
|
||||
|
||||
// Multiple status filter
|
||||
var multiResponse backups_dto.GetBackupsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups?database_id=%s&status=COMPLETED&status=FAILED",
|
||||
database.ID.String(),
|
||||
),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&multiResponse,
|
||||
)
|
||||
|
||||
assert.Equal(t, int64(2), multiResponse.Total)
|
||||
assert.Len(t, multiResponse.Backups, 2)
|
||||
|
||||
for _, backup := range multiResponse.Backups {
|
||||
assert.True(
|
||||
t,
|
||||
backup.Status == backups_core.BackupStatusCompleted ||
|
||||
backup.Status == backups_core.BackupStatusFailed,
|
||||
"expected COMPLETED or FAILED, got %s", backup.Status,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GetBackups_WithBeforeDateFilter_ReturnsFilteredBackups(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
now := time.Now().UTC()
|
||||
cutoff := now.Add(-1 * time.Hour)
|
||||
|
||||
olderBackup := CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: now.Add(-3 * time.Hour),
|
||||
})
|
||||
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: now,
|
||||
})
|
||||
|
||||
var response backups_dto.GetBackupsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups?database_id=%s&beforeDate=%s",
|
||||
database.ID.String(),
|
||||
cutoff.Format(time.RFC3339),
|
||||
),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, int64(1), response.Total)
|
||||
assert.Len(t, response.Backups, 1)
|
||||
assert.Equal(t, olderBackup.ID, response.Backups[0].ID)
|
||||
}
|
||||
|
||||
func Test_GetBackups_WithPgWalBackupTypeFilter_ReturnsFilteredBackups(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
now := time.Now().UTC()
|
||||
fullBackupType := backups_core.PgWalBackupTypeFullBackup
|
||||
walSegmentType := backups_core.PgWalBackupTypeWalSegment
|
||||
|
||||
fullBackup := CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: now.Add(-2 * time.Hour),
|
||||
PgWalBackupType: &fullBackupType,
|
||||
})
|
||||
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: now.Add(-1 * time.Hour),
|
||||
PgWalBackupType: &walSegmentType,
|
||||
})
|
||||
|
||||
var response backups_dto.GetBackupsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups?database_id=%s&pgWalBackupType=PG_FULL_BACKUP",
|
||||
database.ID.String(),
|
||||
),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, int64(1), response.Total)
|
||||
assert.Len(t, response.Backups, 1)
|
||||
assert.Equal(t, fullBackup.ID, response.Backups[0].ID)
|
||||
}
|
||||
|
||||
func Test_GetBackups_WithCombinedFilters_ReturnsFilteredBackups(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
now := time.Now().UTC()
|
||||
cutoff := now.Add(-1 * time.Hour)
|
||||
|
||||
// Old completed — should match
|
||||
oldCompleted := CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: now.Add(-3 * time.Hour),
|
||||
})
|
||||
// Old failed — should NOT match (wrong status)
|
||||
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
CreatedAt: now.Add(-2 * time.Hour),
|
||||
})
|
||||
// New completed — should NOT match (too recent)
|
||||
CreateTestBackupWithOptions(database.ID, storage.ID, TestBackupOptions{
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
CreatedAt: now,
|
||||
})
|
||||
|
||||
var response backups_dto.GetBackupsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups?database_id=%s&status=COMPLETED&beforeDate=%s",
|
||||
database.ID.String(),
|
||||
cutoff.Format(time.RFC3339),
|
||||
),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, int64(1), response.Total)
|
||||
assert.Len(t, response.Backups, 1)
|
||||
assert.Equal(t, oldCompleted.ID, response.Backups[0].ID)
|
||||
}
|
||||
|
||||
func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -376,7 +596,7 @@ func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
|
||||
ownerUser, err := userService.GetUserFromToken(owner.Token)
|
||||
assert.NoError(t, err)
|
||||
|
||||
response, err := backups_services.GetBackupService().GetBackups(ownerUser, database.ID, 10, 0)
|
||||
response, err := backups_services.GetBackupService().GetBackups(ownerUser, database.ID, 10, 0, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(response.Backups))
|
||||
}
|
||||
@@ -1263,7 +1483,7 @@ func Test_MakeBackup_VerifyBackupAndMetadataFilesExistInStorage(t *testing.T) {
|
||||
backuperCancel := backuping.StartBackuperNodeForTest(t, backuperNode)
|
||||
defer backuping.StopBackuperNodeForTest(t, backuperCancel, backuperNode)
|
||||
|
||||
scheduler := backuping.CreateTestScheduler()
|
||||
scheduler := backuping.CreateTestScheduler(nil)
|
||||
schedulerCancel := backuping.StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
@@ -1297,14 +1517,14 @@ func Test_MakeBackup_VerifyBackupAndMetadataFilesExistInStorage(t *testing.T) {
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
|
||||
backupFile, err := backupStorage.GetFile(encryptor, backup.FileName)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
backupFile.Close()
|
||||
|
||||
metadataFile, err := backupStorage.GetFile(encryptor, backup.FileName+".metadata")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
metadataContent, err := io.ReadAll(metadataFile)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
metadataFile.Close()
|
||||
|
||||
var storageMetadata backups_common.BackupMetadata
|
||||
@@ -1838,7 +2058,7 @@ func Test_DeleteBackup_RemovesBackupAndMetadataFilesFromDisk(t *testing.T) {
|
||||
backuperCancel := backuping.StartBackuperNodeForTest(t, backuperNode)
|
||||
defer backuping.StopBackuperNodeForTest(t, backuperCancel, backuperNode)
|
||||
|
||||
scheduler := backuping.CreateTestScheduler()
|
||||
scheduler := backuping.CreateTestScheduler(nil)
|
||||
schedulerCancel := backuping.StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
|
||||
@@ -938,6 +938,42 @@ func Test_GetRestorePlan_WithInvalidBackupId_Returns400(t *testing.T) {
|
||||
assert.Equal(t, "no_backups", errResp.Error)
|
||||
}
|
||||
|
||||
func Test_GetRestorePlan_WithWalSegmentId_ResolvesFullBackupAndReturnsWals(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000012")
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000013")
|
||||
|
||||
WaitForBackupCompletion(t, db.ID, 3, 5*time.Second)
|
||||
|
||||
walSegment, err := backups_core.GetBackupRepository().FindWalSegmentByName(
|
||||
db.ID, "000000010000000100000012",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, walSegment)
|
||||
|
||||
var response backups_dto.GetRestorePlanResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t, router,
|
||||
"/api/v1/backups/postgres/wal/restore/plan?backupId="+walSegment.ID.String(),
|
||||
agentToken,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.NotEqual(t, uuid.Nil, response.FullBackup.BackupID)
|
||||
assert.Equal(t, "000000010000000100000001", response.FullBackup.FullBackupWalStartSegment)
|
||||
assert.Equal(t, "000000010000000100000010", response.FullBackup.FullBackupWalStopSegment)
|
||||
require.Len(t, response.WalSegments, 3)
|
||||
assert.Equal(t, "000000010000000100000011", response.WalSegments[0].SegmentName)
|
||||
assert.Equal(t, "000000010000000100000012", response.WalSegments[1].SegmentName)
|
||||
assert.Equal(t, "000000010000000100000013", response.WalSegments[2].SegmentName)
|
||||
assert.Greater(t, response.TotalSizeBytes, int64(0))
|
||||
}
|
||||
|
||||
func Test_GetRestorePlan_WithInvalidToken_Returns401(t *testing.T) {
|
||||
router, db, storage, _, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
@@ -995,6 +1031,140 @@ func Test_GetRestorePlan_WithInvalidBackupIdFormat_Returns400(t *testing.T) {
|
||||
assert.Contains(t, string(resp.Body), "invalid backupId format")
|
||||
}
|
||||
|
||||
func Test_WalUpload_WalSegment_CompletedBackup_HasNonZeroDuration(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
|
||||
|
||||
WaitForBackupCompletion(t, db.ID, 1, 5*time.Second)
|
||||
|
||||
backups, err := backups_core.GetBackupRepository().FindByDatabaseID(db.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var walBackup *backups_core.Backup
|
||||
for _, b := range backups {
|
||||
if b.PgWalBackupType != nil &&
|
||||
*b.PgWalBackupType == backups_core.PgWalBackupTypeWalSegment {
|
||||
walBackup = b
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, walBackup)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, walBackup.Status)
|
||||
assert.Greater(t, walBackup.BackupDurationMs, int64(0),
|
||||
"WAL segment backup should have non-zero duration")
|
||||
}
|
||||
|
||||
func Test_WalUpload_Basebackup_CompletedBackup_HasNonZeroDuration(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
backupID := uploadBasebackupPhase1(t, router, agentToken)
|
||||
completeFullBackupUpload(t, router, agentToken, backupID,
|
||||
"000000010000000100000001", "000000010000000100000010", nil)
|
||||
|
||||
backup, err := backups_core.GetBackupRepository().FindByID(backupID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
assert.Greater(t, backup.BackupDurationMs, int64(0),
|
||||
"base backup should have non-zero duration")
|
||||
}
|
||||
|
||||
func Test_WalUpload_WalSegment_ProgressUpdatedDuringStream(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
|
||||
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
req := newWalSegmentUploadRequest(pipeReader, agentToken, "000000010000000100000011")
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
router.ServeHTTP(recorder, req)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Write some data so the countingReader registers bytes.
|
||||
_, err := pipeWriter.Write([]byte("wal-segment-progress-data"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the progress tracker to tick (1s interval + margin).
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
backups, err := backups_core.GetBackupRepository().FindByDatabaseID(db.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var walBackup *backups_core.Backup
|
||||
for _, b := range backups {
|
||||
if b.PgWalBackupType != nil &&
|
||||
*b.PgWalBackupType == backups_core.PgWalBackupTypeWalSegment {
|
||||
walBackup = b
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, walBackup)
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, walBackup.Status)
|
||||
assert.Greater(t, walBackup.BackupDurationMs, int64(0),
|
||||
"duration should be tracked in real-time during upload")
|
||||
assert.Greater(t, walBackup.BackupSizeMb, float64(0),
|
||||
"size should be tracked in real-time during upload")
|
||||
|
||||
_ = pipeWriter.Close()
|
||||
<-done
|
||||
}
|
||||
|
||||
func Test_WalUpload_Basebackup_ProgressUpdatedDuringStream(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/full-start", pipeReader)
|
||||
req.Header.Set("Authorization", agentToken)
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
router.ServeHTTP(recorder, req)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Write some data so the countingReader registers bytes.
|
||||
_, err := pipeWriter.Write([]byte("basebackup-progress-data"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the progress tracker to tick (1s interval + margin).
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
backups, err := backups_core.GetBackupRepository().FindByDatabaseID(db.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var fullBackup *backups_core.Backup
|
||||
for _, b := range backups {
|
||||
if b.PgWalBackupType != nil &&
|
||||
*b.PgWalBackupType == backups_core.PgWalBackupTypeFullBackup {
|
||||
fullBackup = b
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, fullBackup)
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, fullBackup.Status)
|
||||
assert.Greater(t, fullBackup.BackupDurationMs, int64(0),
|
||||
"duration should be tracked in real-time during upload")
|
||||
assert.Greater(t, fullBackup.BackupSizeMb, float64(0),
|
||||
"size should be tracked in real-time during upload")
|
||||
|
||||
_ = pipeWriter.Close()
|
||||
<-done
|
||||
}
|
||||
|
||||
func Test_DownloadRestoreFile_UploadThenDownload_ContentMatches(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -95,3 +95,33 @@ func CreateTestBackup(databaseID, storageID uuid.UUID) *backups_core.Backup {
|
||||
|
||||
return backup
|
||||
}
|
||||
|
||||
type TestBackupOptions struct {
|
||||
Status backups_core.BackupStatus
|
||||
CreatedAt time.Time
|
||||
PgWalBackupType *backups_core.PgWalBackupType
|
||||
}
|
||||
|
||||
// CreateTestBackupWithOptions creates a test backup with custom status, time, and WAL type
|
||||
func CreateTestBackupWithOptions(
|
||||
databaseID, storageID uuid.UUID,
|
||||
opts TestBackupOptions,
|
||||
) *backups_core.Backup {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: databaseID,
|
||||
StorageID: storageID,
|
||||
Status: opts.Status,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
PgWalBackupType: opts.PgWalBackupType,
|
||||
CreatedAt: opts.CreatedAt,
|
||||
}
|
||||
|
||||
repo := &backups_core.BackupRepository{}
|
||||
if err := repo.Save(backup); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return backup
|
||||
}
|
||||
|
||||
9
backend/internal/features/backups/backups/core/dto.go
Normal file
9
backend/internal/features/backups/backups/core/dto.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package backups_core
|
||||
|
||||
import "time"
|
||||
|
||||
type BackupFilters struct {
|
||||
Statuses []BackupStatus
|
||||
BeforeDate *time.Time
|
||||
PgWalBackupType *PgWalBackupType
|
||||
}
|
||||
@@ -349,6 +349,34 @@ func (r *BackupRepository) FindWalSegmentByName(
|
||||
return &backup, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) FindLatestCompletedFullWalBackupBefore(
|
||||
databaseID uuid.UUID,
|
||||
before time.Time,
|
||||
) (*Backup, error) {
|
||||
var backup Backup
|
||||
|
||||
err := storage.
|
||||
GetDb().
|
||||
Where(
|
||||
"database_id = ? AND pg_wal_backup_type = ? AND status = ? AND created_at <= ?",
|
||||
databaseID,
|
||||
PgWalBackupTypeFullBackup,
|
||||
BackupStatusCompleted,
|
||||
before,
|
||||
).
|
||||
Order("created_at DESC").
|
||||
First(&backup).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &backup, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) FindStaleUploadedBasebackups(olderThan time.Time) ([]*Backup, error) {
|
||||
var backups []*Backup
|
||||
|
||||
@@ -394,3 +422,67 @@ func (r *BackupRepository) FindLastWalSegmentAfter(
|
||||
|
||||
return &backup, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) FindByDatabaseIDWithFiltersAndPagination(
|
||||
databaseID uuid.UUID,
|
||||
filters *BackupFilters,
|
||||
limit, offset int,
|
||||
) ([]*Backup, error) {
|
||||
var backups []*Backup
|
||||
|
||||
query := storage.
|
||||
GetDb().
|
||||
Where("database_id = ?", databaseID)
|
||||
|
||||
if filters != nil {
|
||||
query = filters.applyToQuery(query)
|
||||
}
|
||||
|
||||
if err := query.
|
||||
Order("created_at DESC").
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&backups).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return backups, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) CountByDatabaseIDWithFilters(
|
||||
databaseID uuid.UUID,
|
||||
filters *BackupFilters,
|
||||
) (int64, error) {
|
||||
var count int64
|
||||
|
||||
query := storage.
|
||||
GetDb().
|
||||
Model(&Backup{}).
|
||||
Where("database_id = ?", databaseID)
|
||||
|
||||
if filters != nil {
|
||||
query = filters.applyToQuery(query)
|
||||
}
|
||||
|
||||
if err := query.Count(&count).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (f *BackupFilters) applyToQuery(query *gorm.DB) *gorm.DB {
|
||||
if len(f.Statuses) > 0 {
|
||||
query = query.Where("status IN ?", f.Statuses)
|
||||
}
|
||||
|
||||
if f.BeforeDate != nil {
|
||||
query = query.Where("created_at < ?", *f.BeforeDate)
|
||||
}
|
||||
|
||||
if f.PgWalBackupType != nil {
|
||||
query = query.Where("pg_wal_backup_type = ?", *f.PgWalBackupType)
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
@@ -13,38 +12,31 @@ type DownloadTokenBackgroundService struct {
|
||||
downloadTokenService *DownloadTokenService
|
||||
logger *slog.Logger
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
s.logger.Info("Starting download token cleanup background service")
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
|
||||
s.logger.Error("Failed to clean expired download tokens", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
if s.hasRun.Swap(true) {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
|
||||
s.logger.Info("Starting download token cleanup background service")
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
|
||||
s.logger.Error("Failed to clean expired download tokens", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
@@ -37,8 +34,6 @@ func init() {
|
||||
downloadTokenBackgroundService = &DownloadTokenBackgroundService{
|
||||
downloadTokenService: downloadTokenService,
|
||||
logger: logger.GetLogger(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -11,9 +11,12 @@ import (
|
||||
)
|
||||
|
||||
type GetBackupsRequest struct {
|
||||
DatabaseID string `form:"database_id" binding:"required"`
|
||||
Limit int `form:"limit"`
|
||||
Offset int `form:"offset"`
|
||||
DatabaseID string `form:"database_id" binding:"required"`
|
||||
Limit int `form:"limit"`
|
||||
Offset int `form:"offset"`
|
||||
Statuses []string `form:"status"`
|
||||
BeforeDate *time.Time `form:"beforeDate"`
|
||||
PgWalBackupType *string `form:"pgWalBackupType"`
|
||||
}
|
||||
|
||||
type GetBackupsResponse struct {
|
||||
|
||||
@@ -2,7 +2,6 @@ package backups_services
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
@@ -59,26 +58,11 @@ func GetWalService() *PostgreWalBackupService {
|
||||
return walService
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
var SetupDependencies = sync.OnceFunc(func() {
|
||||
backups_config.
|
||||
GetBackupConfigService().
|
||||
SetDatabaseStorageChangeListener(backupService)
|
||||
|
||||
func SetupDependencies() {
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
backups_config.
|
||||
GetBackupConfigService().
|
||||
SetDatabaseStorageChangeListener(backupService)
|
||||
|
||||
databases.GetDatabaseService().AddDbRemoveListener(backupService)
|
||||
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
databases.GetDatabaseService().AddDbRemoveListener(backupService)
|
||||
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
|
||||
})
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -38,6 +39,8 @@ func (s *PostgreWalBackupService) UploadWalSegment(
|
||||
walSegmentName string,
|
||||
body io.Reader,
|
||||
) error {
|
||||
uploadStart := time.Now().UTC()
|
||||
|
||||
if err := s.validateWalBackupType(database); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -72,14 +75,22 @@ func (s *PostgreWalBackupService) UploadWalSegment(
|
||||
return fmt.Errorf("failed to create backup record: %w", err)
|
||||
}
|
||||
|
||||
sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, body)
|
||||
inputCounter := &countingReader{r: body}
|
||||
progressDone := make(chan struct{})
|
||||
go s.startProgressTracker(backup, inputCounter, uploadStart, progressDone)
|
||||
|
||||
sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, inputCounter)
|
||||
close(progressDone)
|
||||
|
||||
if streamErr != nil {
|
||||
errMsg := streamErr.Error()
|
||||
backup.BackupDurationMs = time.Since(uploadStart).Milliseconds()
|
||||
s.markFailed(backup, errMsg)
|
||||
|
||||
return fmt.Errorf("upload failed: %w", streamErr)
|
||||
}
|
||||
|
||||
backup.BackupDurationMs = time.Since(uploadStart).Milliseconds()
|
||||
s.markCompleted(backup, sizeBytes)
|
||||
|
||||
return nil
|
||||
@@ -93,6 +104,8 @@ func (s *PostgreWalBackupService) UploadBasebackup(
|
||||
database *databases.Database,
|
||||
body io.Reader,
|
||||
) (uuid.UUID, error) {
|
||||
uploadStart := time.Now().UTC()
|
||||
|
||||
if err := s.validateWalBackupType(database); err != nil {
|
||||
return uuid.Nil, err
|
||||
}
|
||||
@@ -117,9 +130,16 @@ func (s *PostgreWalBackupService) UploadBasebackup(
|
||||
return uuid.Nil, fmt.Errorf("failed to create backup record: %w", err)
|
||||
}
|
||||
|
||||
sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, body)
|
||||
inputCounter := &countingReader{r: body}
|
||||
progressDone := make(chan struct{})
|
||||
go s.startProgressTracker(backup, inputCounter, uploadStart, progressDone)
|
||||
|
||||
sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, inputCounter)
|
||||
close(progressDone)
|
||||
|
||||
if streamErr != nil {
|
||||
errMsg := streamErr.Error()
|
||||
backup.BackupDurationMs = time.Since(uploadStart).Milliseconds()
|
||||
s.markFailed(backup, errMsg)
|
||||
|
||||
return uuid.Nil, fmt.Errorf("upload failed: %w", streamErr)
|
||||
@@ -128,6 +148,7 @@ func (s *PostgreWalBackupService) UploadBasebackup(
|
||||
now := time.Now().UTC()
|
||||
backup.UploadCompletedAt = &now
|
||||
backup.BackupSizeMb = float64(sizeBytes) / (1024 * 1024)
|
||||
backup.BackupDurationMs = time.Since(uploadStart).Milliseconds()
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return uuid.Nil, fmt.Errorf("failed to update backup after upload: %w", err)
|
||||
@@ -483,7 +504,7 @@ func (s *PostgreWalBackupService) streamDirect(
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return cr.n, nil
|
||||
return cr.n.Load(), nil
|
||||
}
|
||||
|
||||
func (s *PostgreWalBackupService) streamEncrypted(
|
||||
@@ -544,7 +565,7 @@ func (s *PostgreWalBackupService) streamEncrypted(
|
||||
backup.EncryptionSalt = &encryptionSetup.SaltBase64
|
||||
backup.EncryptionIV = &encryptionSetup.NonceBase64
|
||||
|
||||
return cr.n, nil
|
||||
return cr.n.Load(), nil
|
||||
}
|
||||
|
||||
func (s *PostgreWalBackupService) markCompleted(backup *backups_core.Backup, sizeBytes int64) {
|
||||
@@ -562,6 +583,31 @@ func (s *PostgreWalBackupService) markCompleted(backup *backups_core.Backup, siz
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PostgreWalBackupService) startProgressTracker(
|
||||
backup *backups_core.Backup,
|
||||
inputCounter *countingReader,
|
||||
uploadStart time.Time,
|
||||
done <-chan struct{},
|
||||
) {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
backup.BackupDurationMs = time.Since(uploadStart).Milliseconds()
|
||||
backup.BackupSizeMb = float64(inputCounter.n.Load()) / (1024 * 1024)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("failed to update backup progress",
|
||||
"backupId", backup.ID, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PostgreWalBackupService) markFailed(backup *backups_core.Backup, errMsg string) {
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.FailMessage = &errMsg
|
||||
@@ -575,11 +621,32 @@ func (s *PostgreWalBackupService) resolveFullBackup(
|
||||
databaseID uuid.UUID,
|
||||
backupID *uuid.UUID,
|
||||
) (*backups_core.Backup, error) {
|
||||
if backupID != nil {
|
||||
return s.backupRepository.FindCompletedFullWalBackupByID(databaseID, *backupID)
|
||||
if backupID == nil {
|
||||
return s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(databaseID)
|
||||
}
|
||||
|
||||
return s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(databaseID)
|
||||
fullBackup, err := s.backupRepository.FindCompletedFullWalBackupByID(databaseID, *backupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if fullBackup != nil {
|
||||
return fullBackup, nil
|
||||
}
|
||||
|
||||
backup, err := s.backupRepository.FindByID(*backupID)
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if backup.DatabaseID != databaseID ||
|
||||
backup.Status != backups_core.BackupStatusCompleted ||
|
||||
backup.PgWalBackupType == nil ||
|
||||
*backup.PgWalBackupType != backups_core.PgWalBackupTypeWalSegment {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return s.backupRepository.FindLatestCompletedFullWalBackupBefore(databaseID, backup.CreatedAt)
|
||||
}
|
||||
|
||||
func (s *PostgreWalBackupService) validateRestoreWalChain(
|
||||
@@ -667,12 +734,12 @@ func (s *PostgreWalBackupService) validateWalBackupType(database *databases.Data
|
||||
|
||||
type countingReader struct {
|
||||
r io.Reader
|
||||
n int64
|
||||
n atomic.Int64
|
||||
}
|
||||
|
||||
func (cr *countingReader) Read(p []byte) (n int, err error) {
|
||||
n, err = cr.r.Read(p)
|
||||
cr.n += int64(n)
|
||||
func (cr *countingReader) Read(p []byte) (int, error) {
|
||||
n, err := cr.r.Read(p)
|
||||
cr.n.Add(int64(n))
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
@@ -109,6 +109,7 @@ func (s *BackupService) GetBackups(
|
||||
user *users_models.User,
|
||||
databaseID uuid.UUID,
|
||||
limit, offset int,
|
||||
filters *backups_core.BackupFilters,
|
||||
) (*backups_dto.GetBackupsResponse, error) {
|
||||
database, err := s.databaseService.GetDatabaseByID(databaseID)
|
||||
if err != nil {
|
||||
@@ -134,12 +135,14 @@ func (s *BackupService) GetBackups(
|
||||
offset = 0
|
||||
}
|
||||
|
||||
backups, err := s.backupRepository.FindByDatabaseIDWithPagination(databaseID, limit, offset)
|
||||
backups, err := s.backupRepository.FindByDatabaseIDWithFiltersAndPagination(
|
||||
databaseID, filters, limit, offset,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
total, err := s.backupRepository.CountByDatabaseID(databaseID)
|
||||
total, err := s.backupRepository.CountByDatabaseIDWithFilters(databaseID, filters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -109,6 +109,7 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
|
||||
"--routines",
|
||||
"--quick",
|
||||
"--skip-extended-insert",
|
||||
"--skip-add-locks",
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
@@ -129,6 +130,8 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
|
||||
if mdb.IsHttps {
|
||||
args = append(args, "--ssl")
|
||||
args = append(args, "--skip-ssl-verify-server-cert")
|
||||
} else {
|
||||
args = append(args, "--skip-ssl")
|
||||
}
|
||||
|
||||
if mdb.Database != nil && *mdb.Database != "" {
|
||||
@@ -278,15 +281,9 @@ func (uc *CreateMariadbBackupUsecase) createTempMyCnfFile(
|
||||
mdbConfig *mariadbtypes.MariadbDatabase,
|
||||
password string,
|
||||
) (string, error) {
|
||||
tempFolder := config.GetEnv().TempFolder
|
||||
if err := os.MkdirAll(tempFolder, 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
|
||||
}
|
||||
if err := os.Chmod(tempFolder, 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp(tempFolder, "mycnf_"+uuid.New().String())
|
||||
// Credential files use OS temp dir (/tmp) because some filesystems
|
||||
// (e.g. ZFS on TrueNAS) ignore chmod, causing "group or world access" errors.
|
||||
tempDir, err := os.MkdirTemp(os.TempDir(), "mycnf_"+uuid.New().String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
@@ -108,6 +108,7 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
|
||||
"--set-gtid-purged=OFF",
|
||||
"--quick",
|
||||
"--skip-extended-insert",
|
||||
"--skip-add-locks",
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
@@ -126,6 +127,8 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
|
||||
|
||||
if my.IsHttps {
|
||||
args = append(args, "--ssl-mode=REQUIRED")
|
||||
} else {
|
||||
args = append(args, "--ssl-mode=DISABLED")
|
||||
}
|
||||
|
||||
if my.Database != nil && *my.Database != "" {
|
||||
@@ -297,15 +300,9 @@ func (uc *CreateMysqlBackupUsecase) createTempMyCnfFile(
|
||||
myConfig *mysqltypes.MysqlDatabase,
|
||||
password string,
|
||||
) (string, error) {
|
||||
tempFolder := config.GetEnv().TempFolder
|
||||
if err := os.MkdirAll(tempFolder, 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
|
||||
}
|
||||
if err := os.Chmod(tempFolder, 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp(tempFolder, "mycnf_"+uuid.New().String())
|
||||
// Credential files use OS temp dir (/tmp) because some filesystems
|
||||
// (e.g. ZFS on TrueNAS) ignore chmod, causing "group or world access" errors.
|
||||
tempDir, err := os.MkdirTemp(os.TempDir(), "mycnf_"+uuid.New().String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
@@ -326,6 +323,8 @@ port=%d
|
||||
|
||||
if myConfig.IsHttps {
|
||||
content += "ssl-mode=REQUIRED\n"
|
||||
} else {
|
||||
content += "ssl-mode=DISABLED\n"
|
||||
}
|
||||
|
||||
err = os.WriteFile(myCnfFile, []byte(content), 0o600)
|
||||
|
||||
@@ -747,15 +747,9 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
|
||||
escapedPassword,
|
||||
)
|
||||
|
||||
tempFolder := config.GetEnv().TempFolder
|
||||
if err := os.MkdirAll(tempFolder, 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
|
||||
}
|
||||
if err := os.Chmod(tempFolder, 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp(tempFolder, "pgpass_"+uuid.New().String())
|
||||
// Credential files use OS temp dir (/tmp) because some filesystems
|
||||
// (e.g. ZFS on TrueNAS) ignore chmod, causing "group or world access" errors.
|
||||
tempDir, err := os.MkdirTemp(os.TempDir(), "pgpass_"+uuid.New().String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temporary directory: %w", err)
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ type BackupConfigController struct {
|
||||
|
||||
func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
router.POST("/backup-configs/save", c.SaveBackupConfig)
|
||||
router.GET("/backup-configs/database/:id/plan", c.GetDatabasePlan)
|
||||
router.GET("/backup-configs/database/:id", c.GetBackupConfigByDbID)
|
||||
router.GET("/backup-configs/storage/:id/is-using", c.IsStorageUsing)
|
||||
router.GET("/backup-configs/storage/:id/databases-count", c.CountDatabasesForStorage)
|
||||
@@ -93,39 +92,6 @@ func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
|
||||
ctx.JSON(http.StatusOK, backupConfig)
|
||||
}
|
||||
|
||||
// GetDatabasePlan
|
||||
// @Summary Get database plan by database ID
|
||||
// @Description Get the plan limits for a specific database (max backup size, max total size, max storage period)
|
||||
// @Tags backup-configs
|
||||
// @Produce json
|
||||
// @Param id path string true "Database ID"
|
||||
// @Success 200 {object} plans.DatabasePlan
|
||||
// @Failure 400 {object} map[string]string "Invalid database ID"
|
||||
// @Failure 401 {object} map[string]string "User not authenticated"
|
||||
// @Failure 404 {object} map[string]string "Database not found or access denied"
|
||||
// @Router /backup-configs/database/{id}/plan [get]
|
||||
func (c *BackupConfigController) GetDatabasePlan(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
|
||||
return
|
||||
}
|
||||
|
||||
plan, err := c.backupConfigService.GetDatabasePlan(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusNotFound, gin.H{"error": "database plan not found"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, plan)
|
||||
}
|
||||
|
||||
// IsStorageUsing
|
||||
// @Summary Check if storage is being used
|
||||
// @Description Check if a storage is currently being used by any backup configuration
|
||||
|
||||
@@ -17,14 +17,12 @@ import (
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/features/storages"
|
||||
local_storage "databasus-backend/internal/features/storages/models/local"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
"databasus-backend/internal/util/period"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
"databasus-backend/internal/util/tools"
|
||||
@@ -326,218 +324,13 @@ func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T)
|
||||
&response,
|
||||
)
|
||||
|
||||
var plan plans.DatabasePlan
|
||||
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&plan,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.False(t, response.IsBackupsEnabled)
|
||||
assert.Equal(t, plan.MaxStoragePeriod, response.RetentionTimePeriod)
|
||||
assert.Equal(t, plan.MaxBackupSizeMB, response.MaxBackupSizeMB)
|
||||
assert.Equal(t, plan.MaxBackupsTotalSizeMB, response.MaxBackupsTotalSizeMB)
|
||||
assert.True(t, response.IsRetryIfFailed)
|
||||
assert.Equal(t, 3, response.MaxFailedTriesCount)
|
||||
assert.NotNil(t, response.BackupInterval)
|
||||
}
|
||||
|
||||
func Test_GetDatabasePlan_ForNewDatabase_PlanAlwaysReturned(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
var response plans.DatabasePlan
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.NotNil(t, response.MaxBackupSizeMB)
|
||||
assert.NotNil(t, response.MaxBackupsTotalSizeMB)
|
||||
assert.NotEmpty(t, response.MaxStoragePeriod)
|
||||
}
|
||||
|
||||
func Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Get plan via API (triggers auto-creation)
|
||||
var plan plans.DatabasePlan
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&plan,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, plan.DatabaseID)
|
||||
|
||||
// Adjust plan limits directly in database to fixed restrictive values
|
||||
err := storage.GetDb().Model(&plans.DatabasePlan{}).
|
||||
Where("database_id = ?", database.ID).
|
||||
Updates(map[string]any{
|
||||
"max_backup_size_mb": 100,
|
||||
"max_backups_total_size_mb": 1000,
|
||||
"max_storage_period": period.PeriodMonth,
|
||||
}).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test 1: Try to save backup config with exceeded backup size limit
|
||||
timeOfDay := "04:00"
|
||||
backupConfigExceededSize := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 200, // Exceeds limit of 100
|
||||
MaxBackupsTotalSizeMB: 800,
|
||||
}
|
||||
|
||||
respExceededSize := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
backupConfigExceededSize,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
assert.Contains(t, string(respExceededSize.Body), "max backup size exceeds plan limit")
|
||||
|
||||
// Test 2: Try to save backup config with exceeded total size limit
|
||||
backupConfigExceededTotal := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 50,
|
||||
MaxBackupsTotalSizeMB: 2000, // Exceeds limit of 1000
|
||||
}
|
||||
|
||||
respExceededTotal := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
backupConfigExceededTotal,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
assert.Contains(t, string(respExceededTotal.Body), "max total backups size exceeds plan limit")
|
||||
|
||||
// Test 3: Try to save backup config with exceeded storage period limit
|
||||
backupConfigExceededPeriod := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodYear, // Exceeds limit of Month
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 80,
|
||||
MaxBackupsTotalSizeMB: 800,
|
||||
}
|
||||
|
||||
respExceededPeriod := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
backupConfigExceededPeriod,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
assert.Contains(t, string(respExceededPeriod.Body), "storage period exceeds plan limit")
|
||||
|
||||
// Test 4: Save backup config within all limits - should succeed
|
||||
backupConfigValid := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodWeek, // Within Month limit
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 80, // Within 100 limit
|
||||
MaxBackupsTotalSizeMB: 800, // Within 1000 limit
|
||||
}
|
||||
|
||||
var responseValid BackupConfig
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
backupConfigValid,
|
||||
http.StatusOK,
|
||||
&responseValid,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, responseValid.DatabaseID)
|
||||
assert.Equal(t, int64(80), responseValid.MaxBackupSizeMB)
|
||||
assert.Equal(t, int64(800), responseValid.MaxBackupsTotalSizeMB)
|
||||
assert.Equal(t, period.PeriodWeek, responseValid.RetentionTimePeriod)
|
||||
}
|
||||
|
||||
func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -2,14 +2,11 @@ package backups_config
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/features/storages"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -20,7 +17,6 @@ var (
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
plans.GetDatabasePlanService(),
|
||||
nil,
|
||||
}
|
||||
)
|
||||
@@ -37,21 +33,6 @@ func GetBackupConfigService() *BackupConfigService {
|
||||
return backupConfigService
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
var SetupDependencies = sync.OnceFunc(func() {
|
||||
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
|
||||
})
|
||||
|
||||
@@ -7,5 +7,5 @@ type TransferDatabaseRequest struct {
|
||||
TargetStorageID *uuid.UUID `json:"targetStorageId,omitempty"`
|
||||
IsTransferWithStorage bool `json:"isTransferWithStorage,omitempty"`
|
||||
IsTransferWithNotifiers bool `json:"isTransferWithNotifiers,omitempty"`
|
||||
TargetNotifierIDs []uuid.UUID `json:"targetNotifierIds,omitempty"`
|
||||
TargetNotifierIDs []uuid.UUID `json:"targetNotifierIds,omitzero"`
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/features/storages"
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
@@ -29,8 +28,8 @@ type BackupConfig struct {
|
||||
RetentionGfsMonths int `json:"retentionGfsMonths" gorm:"column:retention_gfs_months;type:int;not null;default:0"`
|
||||
RetentionGfsYears int `json:"retentionGfsYears" gorm:"column:retention_gfs_years;type:int;not null;default:0"`
|
||||
|
||||
BackupIntervalID uuid.UUID `json:"backupIntervalId" gorm:"column:backup_interval_id;type:uuid;not null"`
|
||||
BackupInterval *intervals.Interval `json:"backupInterval,omitempty" gorm:"foreignKey:BackupIntervalID"`
|
||||
BackupIntervalID uuid.UUID `json:"backupIntervalId" gorm:"column:backup_interval_id;type:uuid;not null"`
|
||||
BackupInterval *intervals.Interval `json:"backupInterval,omitzero" gorm:"foreignKey:BackupIntervalID"`
|
||||
|
||||
Storage *storages.Storage `json:"storage" gorm:"foreignKey:StorageID"`
|
||||
StorageID *uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;"`
|
||||
@@ -42,11 +41,6 @@ type BackupConfig struct {
|
||||
MaxFailedTriesCount int `json:"maxFailedTriesCount" gorm:"column:max_failed_tries_count;type:int;not null"`
|
||||
|
||||
Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
|
||||
|
||||
// MaxBackupSizeMB limits individual backup size. 0 = unlimited.
|
||||
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
|
||||
// MaxBackupsTotalSizeMB limits total size of all backups. 0 = unlimited.
|
||||
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
|
||||
}
|
||||
|
||||
func (h *BackupConfig) TableName() string {
|
||||
@@ -86,12 +80,12 @@ func (b *BackupConfig) AfterFind(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
|
||||
func (b *BackupConfig) Validate() error {
|
||||
if b.BackupIntervalID == uuid.Nil && b.BackupInterval == nil {
|
||||
return errors.New("backup interval is required")
|
||||
}
|
||||
|
||||
if err := b.validateRetentionPolicy(plan); err != nil {
|
||||
if err := b.validateRetentionPolicy(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -110,67 +104,38 @@ func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
|
||||
}
|
||||
}
|
||||
|
||||
if b.MaxBackupSizeMB < 0 {
|
||||
return errors.New("max backup size must be non-negative")
|
||||
}
|
||||
|
||||
if b.MaxBackupsTotalSizeMB < 0 {
|
||||
return errors.New("max backups total size must be non-negative")
|
||||
}
|
||||
|
||||
if plan.MaxBackupSizeMB > 0 {
|
||||
if b.MaxBackupSizeMB == 0 || b.MaxBackupSizeMB > plan.MaxBackupSizeMB {
|
||||
return errors.New("max backup size exceeds plan limit")
|
||||
}
|
||||
}
|
||||
|
||||
if plan.MaxBackupsTotalSizeMB > 0 {
|
||||
if b.MaxBackupsTotalSizeMB == 0 ||
|
||||
b.MaxBackupsTotalSizeMB > plan.MaxBackupsTotalSizeMB {
|
||||
return errors.New("max total backups size exceeds plan limit")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
|
||||
return &BackupConfig{
|
||||
DatabaseID: newDatabaseID,
|
||||
IsBackupsEnabled: b.IsBackupsEnabled,
|
||||
RetentionPolicyType: b.RetentionPolicyType,
|
||||
RetentionTimePeriod: b.RetentionTimePeriod,
|
||||
RetentionCount: b.RetentionCount,
|
||||
RetentionGfsHours: b.RetentionGfsHours,
|
||||
RetentionGfsDays: b.RetentionGfsDays,
|
||||
RetentionGfsWeeks: b.RetentionGfsWeeks,
|
||||
RetentionGfsMonths: b.RetentionGfsMonths,
|
||||
RetentionGfsYears: b.RetentionGfsYears,
|
||||
BackupIntervalID: uuid.Nil,
|
||||
BackupInterval: b.BackupInterval.Copy(),
|
||||
StorageID: b.StorageID,
|
||||
SendNotificationsOn: b.SendNotificationsOn,
|
||||
IsRetryIfFailed: b.IsRetryIfFailed,
|
||||
MaxFailedTriesCount: b.MaxFailedTriesCount,
|
||||
Encryption: b.Encryption,
|
||||
MaxBackupSizeMB: b.MaxBackupSizeMB,
|
||||
MaxBackupsTotalSizeMB: b.MaxBackupsTotalSizeMB,
|
||||
DatabaseID: newDatabaseID,
|
||||
IsBackupsEnabled: b.IsBackupsEnabled,
|
||||
RetentionPolicyType: b.RetentionPolicyType,
|
||||
RetentionTimePeriod: b.RetentionTimePeriod,
|
||||
RetentionCount: b.RetentionCount,
|
||||
RetentionGfsHours: b.RetentionGfsHours,
|
||||
RetentionGfsDays: b.RetentionGfsDays,
|
||||
RetentionGfsWeeks: b.RetentionGfsWeeks,
|
||||
RetentionGfsMonths: b.RetentionGfsMonths,
|
||||
RetentionGfsYears: b.RetentionGfsYears,
|
||||
BackupIntervalID: uuid.Nil,
|
||||
BackupInterval: b.BackupInterval.Copy(),
|
||||
StorageID: b.StorageID,
|
||||
SendNotificationsOn: b.SendNotificationsOn,
|
||||
IsRetryIfFailed: b.IsRetryIfFailed,
|
||||
MaxFailedTriesCount: b.MaxFailedTriesCount,
|
||||
Encryption: b.Encryption,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BackupConfig) validateRetentionPolicy(plan *plans.DatabasePlan) error {
|
||||
func (b *BackupConfig) validateRetentionPolicy() error {
|
||||
switch b.RetentionPolicyType {
|
||||
case RetentionPolicyTypeTimePeriod, "":
|
||||
if b.RetentionTimePeriod == "" {
|
||||
return errors.New("retention time period is required")
|
||||
}
|
||||
|
||||
if plan.MaxStoragePeriod != period.PeriodForever {
|
||||
if b.RetentionTimePeriod.CompareTo(plan.MaxStoragePeriod) > 0 {
|
||||
return errors.New("storage period exceeds plan limit")
|
||||
}
|
||||
}
|
||||
|
||||
case RetentionPolicyTypeCount:
|
||||
if b.RetentionCount <= 0 {
|
||||
return errors.New("retention count must be greater than 0")
|
||||
|
||||
@@ -6,248 +6,34 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
|
||||
func Test_Validate_WhenRetentionTimePeriodIsWeekAndPlanAllowsMonth_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodWeek
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodMonth
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRetentionTimePeriodIsYearAndPlanAllowsMonth_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodYear
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodMonth
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "storage period exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRetentionTimePeriodIsForeverAndPlanAllowsForever_ValidationPasses(
|
||||
t *testing.T,
|
||||
) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodForever
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodForever
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRetentionTimePeriodIsForeverAndPlanAllowsYear_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodForever
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodYear
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "storage period exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRetentionTimePeriodEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodMonth
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodMonth
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSize100MBAndPlanAllows500MB_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 100
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 500
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSize500MBAndPlanAllows100MB_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 500
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 100
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max backup size exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 0
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanHas500MBLimit_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 500
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max backup size exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 500
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 500
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSize1GBAndPlanAllows5GB_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 1000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSize5GBAndPlanAllows1GB_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 1000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max total backups size exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 0
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanHas1GBLimit_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 1000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max total backups size exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenAllLimitsAreUnlimitedInPlan_AnyConfigurationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodForever
|
||||
config.MaxBackupSizeMB = 0
|
||||
config.MaxBackupsTotalSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenMultipleLimitsExceeded_ValidationFailsWithFirstError(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = period.PeriodYear
|
||||
config.MaxBackupSizeMB = 500
|
||||
config.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodMonth
|
||||
plan.MaxBackupSizeMB = 100
|
||||
plan.MaxBackupsTotalSizeMB = 1000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.Error(t, err)
|
||||
assert.EqualError(t, err, "storage period exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenConfigHasInvalidIntervalButPlanIsValid_ValidationFailsOnInterval(
|
||||
t *testing.T,
|
||||
) {
|
||||
func Test_Validate_WhenIntervalIsMissing_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.BackupIntervalID = uuid.Nil
|
||||
config.BackupInterval = nil
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "backup interval is required")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenIntervalIsMissing_ValidationFailsRegardlessOfPlan(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.BackupIntervalID = uuid.Nil
|
||||
config.BackupInterval = nil
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "backup interval is required")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRetryEnabledButMaxTriesIsZero_ValidationFailsRegardlessOfPlan(t *testing.T) {
|
||||
func Test_Validate_WhenRetryEnabledButMaxTriesIsZero_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.IsRetryIfFailed = true
|
||||
config.MaxFailedTriesCount = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "max failed tries count must be greater than 0")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenEncryptionIsInvalid_ValidationFailsRegardlessOfPlan(t *testing.T) {
|
||||
func Test_Validate_WhenEncryptionIsInvalid_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.Encryption = "INVALID"
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "encryption must be NONE or ENCRYPTED")
|
||||
}
|
||||
|
||||
@@ -255,125 +41,16 @@ func Test_Validate_WhenRetentionTimePeriodIsEmpty_ValidationFails(t *testing.T)
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = ""
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "retention time period is required")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenMaxBackupSizeIsNegative_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = -100
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max backup size must be non-negative")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenMaxTotalSizeIsNegative_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = -1000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max backups total size must be non-negative")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenPlanLimitsAreAtBoundary_ValidationWorks(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
configPeriod period.TimePeriod
|
||||
planPeriod period.TimePeriod
|
||||
configSize int64
|
||||
planSize int64
|
||||
configTotal int64
|
||||
planTotal int64
|
||||
shouldSucceed bool
|
||||
}{
|
||||
{
|
||||
name: "all values just under limit",
|
||||
configPeriod: period.PeriodWeek,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 99,
|
||||
planSize: 100,
|
||||
configTotal: 999,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: true,
|
||||
},
|
||||
{
|
||||
name: "all values equal to limit",
|
||||
configPeriod: period.PeriodMonth,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 100,
|
||||
planSize: 100,
|
||||
configTotal: 1000,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: true,
|
||||
},
|
||||
{
|
||||
name: "period just over limit",
|
||||
configPeriod: period.Period3Month,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 100,
|
||||
planSize: 100,
|
||||
configTotal: 1000,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: false,
|
||||
},
|
||||
{
|
||||
name: "size just over limit",
|
||||
configPeriod: period.PeriodMonth,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 101,
|
||||
planSize: 100,
|
||||
configTotal: 1000,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: false,
|
||||
},
|
||||
{
|
||||
name: "total size just over limit",
|
||||
configPeriod: period.PeriodMonth,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 100,
|
||||
planSize: 100,
|
||||
configTotal: 1001,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionTimePeriod = tt.configPeriod
|
||||
config.MaxBackupSizeMB = tt.configSize
|
||||
config.MaxBackupsTotalSizeMB = tt.configTotal
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = tt.planPeriod
|
||||
plan.MaxBackupSizeMB = tt.planSize
|
||||
plan.MaxBackupsTotalSizeMB = tt.planTotal
|
||||
|
||||
err := config.Validate(plan)
|
||||
if tt.shouldSucceed {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Validate_WhenPolicyTypeIsCount_RequiresPositiveCount(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionPolicyType = RetentionPolicyTypeCount
|
||||
config.RetentionCount = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "retention count must be greater than 0")
|
||||
}
|
||||
|
||||
@@ -382,9 +59,7 @@ func Test_Validate_WhenPolicyTypeIsCount_WithPositiveCount_ValidationPasses(t *t
|
||||
config.RetentionPolicyType = RetentionPolicyTypeCount
|
||||
config.RetentionCount = 10
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -396,9 +71,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_RequiresAtLeastOneField(t *testing.T) {
|
||||
config.RetentionGfsMonths = 0
|
||||
config.RetentionGfsYears = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "at least one GFS retention field must be greater than 0")
|
||||
}
|
||||
|
||||
@@ -407,9 +80,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_WithOnlyHours_ValidationPasses(t *testing
|
||||
config.RetentionPolicyType = RetentionPolicyTypeGFS
|
||||
config.RetentionGfsHours = 24
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -418,9 +89,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_WithOnlyDays_ValidationPasses(t *testing.
|
||||
config.RetentionPolicyType = RetentionPolicyTypeGFS
|
||||
config.RetentionGfsDays = 7
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -433,9 +102,7 @@ func Test_Validate_WhenPolicyTypeIsGFS_WithAllFields_ValidationPasses(t *testing
|
||||
config.RetentionGfsMonths = 12
|
||||
config.RetentionGfsYears = 3
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -443,35 +110,59 @@ func Test_Validate_WhenPolicyTypeIsInvalid_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.RetentionPolicyType = "INVALID"
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
err := config.Validate()
|
||||
assert.EqualError(t, err, "invalid retention policy type")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenCloudAndEncryptionIsNotEncrypted_ValidationFails(t *testing.T) {
|
||||
enableCloud(t)
|
||||
|
||||
backupConfig := createValidBackupConfig()
|
||||
backupConfig.Encryption = BackupEncryptionNone
|
||||
|
||||
err := backupConfig.Validate()
|
||||
assert.EqualError(t, err, "encryption is mandatory for cloud storage")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenCloudAndEncryptionIsEncrypted_ValidationPasses(t *testing.T) {
|
||||
enableCloud(t)
|
||||
|
||||
backupConfig := createValidBackupConfig()
|
||||
backupConfig.Encryption = BackupEncryptionEncrypted
|
||||
|
||||
err := backupConfig.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenNotCloudAndEncryptionIsNotEncrypted_ValidationPasses(t *testing.T) {
|
||||
backupConfig := createValidBackupConfig()
|
||||
backupConfig.Encryption = BackupEncryptionNone
|
||||
|
||||
err := backupConfig.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func enableCloud(t *testing.T) {
|
||||
t.Helper()
|
||||
config.GetEnv().IsCloud = true
|
||||
t.Cleanup(func() {
|
||||
config.GetEnv().IsCloud = false
|
||||
})
|
||||
}
|
||||
|
||||
func createValidBackupConfig() *BackupConfig {
|
||||
intervalID := uuid.New()
|
||||
return &BackupConfig{
|
||||
DatabaseID: uuid.New(),
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodMonth,
|
||||
BackupIntervalID: intervalID,
|
||||
BackupInterval: &intervals.Interval{ID: intervalID},
|
||||
SendNotificationsOn: []BackupNotificationType{},
|
||||
IsRetryIfFailed: false,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 100,
|
||||
MaxBackupsTotalSizeMB: 1000,
|
||||
}
|
||||
}
|
||||
|
||||
func createUnlimitedPlan() *plans.DatabasePlan {
|
||||
return &plans.DatabasePlan{
|
||||
DatabaseID: uuid.New(),
|
||||
MaxBackupSizeMB: 0,
|
||||
MaxBackupsTotalSizeMB: 0,
|
||||
MaxStoragePeriod: period.PeriodForever,
|
||||
return &BackupConfig{
|
||||
DatabaseID: uuid.New(),
|
||||
IsBackupsEnabled: true,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.PeriodMonth,
|
||||
BackupIntervalID: intervalID,
|
||||
BackupInterval: &intervals.Interval{ID: intervalID},
|
||||
SendNotificationsOn: []BackupNotificationType{},
|
||||
IsRetryIfFailed: false,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,6 +62,14 @@ func (r *BackupConfigRepository) FindByDatabaseID(databaseID uuid.UUID) (*Backup
|
||||
GetDb().
|
||||
Preload("BackupInterval").
|
||||
Preload("Storage").
|
||||
Preload("Storage.LocalStorage").
|
||||
Preload("Storage.S3Storage").
|
||||
Preload("Storage.GoogleDriveStorage").
|
||||
Preload("Storage.NASStorage").
|
||||
Preload("Storage.AzureBlobStorage").
|
||||
Preload("Storage.FTPStorage").
|
||||
Preload("Storage.SFTPStorage").
|
||||
Preload("Storage.RcloneStorage").
|
||||
Where("database_id = ?", databaseID).
|
||||
First(&backupConfig).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
@@ -81,6 +89,14 @@ func (r *BackupConfigRepository) GetWithEnabledBackups() ([]*BackupConfig, error
|
||||
GetDb().
|
||||
Preload("BackupInterval").
|
||||
Preload("Storage").
|
||||
Preload("Storage.LocalStorage").
|
||||
Preload("Storage.S3Storage").
|
||||
Preload("Storage.GoogleDriveStorage").
|
||||
Preload("Storage.NASStorage").
|
||||
Preload("Storage.AzureBlobStorage").
|
||||
Preload("Storage.FTPStorage").
|
||||
Preload("Storage.SFTPStorage").
|
||||
Preload("Storage.RcloneStorage").
|
||||
Where("is_backups_enabled = ?", true).
|
||||
Find(&backupConfigs).Error; err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -8,10 +8,10 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_models "databasus-backend/internal/features/users/models"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
|
||||
type BackupConfigService struct {
|
||||
@@ -20,7 +20,6 @@ type BackupConfigService struct {
|
||||
storageService *storages.StorageService
|
||||
notifierService *notifiers.NotifierService
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
databasePlanService *plans.DatabasePlanService
|
||||
|
||||
dbStorageChangeListener BackupConfigStorageChangeListener
|
||||
}
|
||||
@@ -46,12 +45,7 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
|
||||
user *users_models.User,
|
||||
backupConfig *BackupConfig,
|
||||
) (*BackupConfig, error) {
|
||||
plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := backupConfig.Validate(plan); err != nil {
|
||||
if err := backupConfig.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -88,12 +82,7 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
|
||||
func (s *BackupConfigService) SaveBackupConfig(
|
||||
backupConfig *BackupConfig,
|
||||
) (*BackupConfig, error) {
|
||||
plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := backupConfig.Validate(plan); err != nil {
|
||||
if err := backupConfig.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -131,18 +120,6 @@ func (s *BackupConfigService) GetBackupConfigByDbIdWithAuth(
|
||||
return s.GetBackupConfigByDbId(databaseID)
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) GetDatabasePlan(
|
||||
user *users_models.User,
|
||||
databaseID uuid.UUID,
|
||||
) (*plans.DatabasePlan, error) {
|
||||
_, err := s.databaseService.GetDatabase(user, databaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.databasePlanService.GetDatabasePlan(databaseID)
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) GetBackupConfigByDbId(
|
||||
databaseID uuid.UUID,
|
||||
) (*BackupConfig, error) {
|
||||
@@ -322,20 +299,13 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace(
|
||||
func (s *BackupConfigService) initializeDefaultConfig(
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
plan, err := s.databasePlanService.GetDatabasePlan(databaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
timeOfDay := "04:00"
|
||||
|
||||
_, err = s.backupConfigRepository.Save(&BackupConfig{
|
||||
DatabaseID: databaseID,
|
||||
IsBackupsEnabled: false,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: plan.MaxStoragePeriod,
|
||||
MaxBackupSizeMB: plan.MaxBackupSizeMB,
|
||||
MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB,
|
||||
_, err := s.backupConfigRepository.Save(&BackupConfig{
|
||||
DatabaseID: databaseID,
|
||||
IsBackupsEnabled: false,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: period.Period3Month,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
|
||||
305
backend/internal/features/billing/controller.go
Normal file
305
backend/internal/features/billing/controller.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
type BillingController struct {
|
||||
billingService *BillingService
|
||||
}
|
||||
|
||||
func (c *BillingController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
billing := router.Group("/billing")
|
||||
|
||||
billing.POST("/subscription", c.CreateSubscription)
|
||||
billing.POST("/subscription/change-storage", c.ChangeSubscriptionStorage)
|
||||
billing.POST("/subscription/portal/:subscription_id", c.GetPortalSession)
|
||||
billing.GET("/subscription/events/:subscription_id", c.GetSubscriptionEvents)
|
||||
billing.GET("/subscription/invoices/:subscription_id", c.GetInvoices)
|
||||
billing.GET("/subscription/:database_id", c.GetSubscription)
|
||||
}
|
||||
|
||||
// CreateSubscription
|
||||
// @Summary Create a new subscription
|
||||
// @Description Create a billing subscription for the specified database with the given storage
|
||||
// @Tags billing
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body CreateSubscriptionRequest true "Subscription creation data"
|
||||
// @Success 200 {object} CreateSubscriptionResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /billing/subscription [post]
|
||||
func (c *BillingController) CreateSubscription(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request CreateSubscriptionRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(400, gin.H{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
log := logger.GetLogger().With(
|
||||
"request_id", uuid.New(),
|
||||
"database_id", request.DatabaseID,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
transactionID, err := c.billingService.CreateSubscription(
|
||||
log,
|
||||
user,
|
||||
request.DatabaseID,
|
||||
request.StorageGB,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error("Failed to create subscription", "error", err)
|
||||
ctx.JSON(500, gin.H{"error": "Failed to create subscription"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(200, CreateSubscriptionResponse{PaddleTransactionID: transactionID})
|
||||
}
|
||||
|
||||
// ChangeSubscriptionStorage
|
||||
// @Summary Change subscription storage
|
||||
// @Description Update the storage allocation for an existing subscription
|
||||
// @Tags billing
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body ChangeStorageRequest true "New storage configuration"
|
||||
// @Success 200 {object} ChangeStorageResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /billing/subscription/change-storage [post]
|
||||
func (c *BillingController) ChangeSubscriptionStorage(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request ChangeStorageRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(400, gin.H{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
log := logger.GetLogger().With(
|
||||
"request_id", uuid.New(),
|
||||
"database_id", request.DatabaseID,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
result, err := c.billingService.ChangeSubscriptionStorage(log, user, request.DatabaseID, request.StorageGB)
|
||||
if err != nil {
|
||||
log.Error("Failed to change subscription storage", "error", err)
|
||||
ctx.JSON(500, gin.H{"error": "Failed to change subscription storage"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(200, ChangeStorageResponse{
|
||||
ApplyMode: result.ApplyMode,
|
||||
CurrentGB: result.CurrentGB,
|
||||
PendingGB: result.PendingGB,
|
||||
})
|
||||
}
|
||||
|
||||
// GetPortalSession
|
||||
// @Summary Get billing portal session
|
||||
// @Description Generate a portal session URL for managing the subscription
|
||||
// @Tags billing
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param subscription_id path string true "Subscription ID"
|
||||
// @Success 200 {object} GetPortalSessionResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /billing/subscription/portal/{subscription_id} [post]
|
||||
func (c *BillingController) GetPortalSession(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
subscriptionID := ctx.Param("subscription_id")
|
||||
if subscriptionID == "" {
|
||||
ctx.JSON(400, gin.H{"error": "Subscription ID is required"})
|
||||
return
|
||||
}
|
||||
|
||||
log := logger.GetLogger().With(
|
||||
"request_id", uuid.New(),
|
||||
"subscription_id", subscriptionID,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
url, err := c.billingService.GetPortalURL(log, user, uuid.MustParse(subscriptionID))
|
||||
if err != nil {
|
||||
log.Error("Failed to get portal session", "error", err)
|
||||
ctx.JSON(500, gin.H{"error": "Failed to get portal session"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(200, GetPortalSessionResponse{PortalURL: url})
|
||||
}
|
||||
|
||||
// GetSubscriptionEvents
|
||||
// @Summary Get subscription events
|
||||
// @Description Retrieve the event history for a subscription
|
||||
// @Tags billing
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param subscription_id path string true "Subscription ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Success 200 {object} GetSubscriptionEventsResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /billing/subscription/events/{subscription_id} [get]
|
||||
func (c *BillingController) GetSubscriptionEvents(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
subscriptionID, err := uuid.Parse(ctx.Param("subscription_id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid subscription ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var request PaginatedRequest
|
||||
if err := ctx.ShouldBindQuery(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
log := logger.GetLogger().With(
|
||||
"request_id", uuid.New(),
|
||||
"subscription_id", subscriptionID,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
response, err := c.billingService.GetSubscriptionEvents(log, user, subscriptionID, request.Limit, request.Offset)
|
||||
if err != nil {
|
||||
log.Error("Failed to get subscription events", "error", err)
|
||||
ctx.JSON(500, gin.H{"error": "Failed to get subscription events"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(200, response)
|
||||
}
|
||||
|
||||
// GetInvoices
|
||||
// @Summary Get subscription invoices
|
||||
// @Description Retrieve all invoices for a subscription
|
||||
// @Tags billing
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param subscription_id path string true "Subscription ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Success 200 {object} GetInvoicesResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /billing/subscription/invoices/{subscription_id} [get]
|
||||
func (c *BillingController) GetInvoices(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
subscriptionID, err := uuid.Parse(ctx.Param("subscription_id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid subscription ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var request PaginatedRequest
|
||||
if err := ctx.ShouldBindQuery(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
log := logger.GetLogger().With(
|
||||
"request_id", uuid.New(),
|
||||
"subscription_id", subscriptionID,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
response, err := c.billingService.GetSubscriptionInvoices(log, user, subscriptionID, request.Limit, request.Offset)
|
||||
if err != nil {
|
||||
log.Error("Failed to get invoices", "error", err)
|
||||
ctx.JSON(500, gin.H{"error": "Failed to get invoices"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(200, response)
|
||||
}
|
||||
|
||||
// GetSubscription
|
||||
// @Summary Get subscription by database
|
||||
// @Description Retrieve the subscription associated with a specific database
|
||||
// @Tags billing
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param database_id path string true "Database ID"
|
||||
// @Success 200 {object} billing_models.Subscription
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /billing/subscription/{database_id} [get]
|
||||
func (c *BillingController) GetSubscription(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
databaseID, err := uuid.Parse(ctx.Param("database_id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid database ID"})
|
||||
return
|
||||
}
|
||||
|
||||
log := logger.GetLogger().With(
|
||||
"request_id", uuid.New(),
|
||||
"database_id", databaseID,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
subscription, err := c.billingService.GetSubscriptionByDatabaseID(log, user, databaseID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSubscriptionNotFound) {
|
||||
ctx.JSON(http.StatusNotFound, gin.H{"error": "Subscription not found"})
|
||||
return
|
||||
}
|
||||
|
||||
log.Error("failed to get subscription", "error", err)
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get subscription"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(200, subscription)
|
||||
}
|
||||
1450
backend/internal/features/billing/controller_test.go
Normal file
1450
backend/internal/features/billing/controller_test.go
Normal file
File diff suppressed because it is too large
Load Diff
35
backend/internal/features/billing/di.go
Normal file
35
backend/internal/features/billing/di.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
billing_repositories "databasus-backend/internal/features/billing/repositories"
|
||||
"databasus-backend/internal/features/databases"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
)
|
||||
|
||||
var (
|
||||
billingService = &BillingService{
|
||||
&billing_repositories.SubscriptionRepository{},
|
||||
&billing_repositories.SubscriptionEventRepository{},
|
||||
&billing_repositories.InvoiceRepository{},
|
||||
nil, // billing provider will be set later to avoid circular dependency
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
*databases.GetDatabaseService(),
|
||||
atomic.Bool{},
|
||||
}
|
||||
billingController = &BillingController{billingService}
|
||||
)
|
||||
|
||||
func GetBillingService() *BillingService {
|
||||
return billingService
|
||||
}
|
||||
|
||||
func GetBillingController() *BillingController {
|
||||
return billingController
|
||||
}
|
||||
|
||||
var SetupDependencies = sync.OnceFunc(func() {
|
||||
databases.GetDatabaseService().AddDbCreationListener(billingService)
|
||||
})
|
||||
67
backend/internal/features/billing/dto.go
Normal file
67
backend/internal/features/billing/dto.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
billing_models "databasus-backend/internal/features/billing/models"
|
||||
)
|
||||
|
||||
type CreateSubscriptionRequest struct {
|
||||
DatabaseID uuid.UUID `json:"databaseId" validate:"required"`
|
||||
StorageGB int `json:"storageGb" validate:"required,min=1"`
|
||||
}
|
||||
|
||||
type CreateSubscriptionResponse struct {
|
||||
PaddleTransactionID string `json:"paddleTransactionId"`
|
||||
}
|
||||
|
||||
type ChangeStorageApplyMode string
|
||||
|
||||
const (
|
||||
ChangeStorageApplyImmediate ChangeStorageApplyMode = "immediate"
|
||||
ChangeStorageApplyNextCycle ChangeStorageApplyMode = "next_cycle"
|
||||
)
|
||||
|
||||
type ChangeStorageRequest struct {
|
||||
DatabaseID uuid.UUID `json:"databaseId" validate:"required"`
|
||||
StorageGB int `json:"storageGb" validate:"required,min=1"`
|
||||
}
|
||||
|
||||
type ChangeStorageResponse struct {
|
||||
ApplyMode ChangeStorageApplyMode `json:"applyMode"`
|
||||
CurrentGB int `json:"currentGb"`
|
||||
PendingGB *int `json:"pendingGb,omitempty"`
|
||||
}
|
||||
|
||||
type PortalResponse struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type ChangeStorageResult struct {
|
||||
ApplyMode ChangeStorageApplyMode
|
||||
CurrentGB int
|
||||
PendingGB *int
|
||||
}
|
||||
|
||||
type GetPortalSessionResponse struct {
|
||||
PortalURL string `json:"url"`
|
||||
}
|
||||
|
||||
type PaginatedRequest struct {
|
||||
Limit int `form:"limit" json:"limit"`
|
||||
Offset int `form:"offset" json:"offset"`
|
||||
}
|
||||
|
||||
type GetSubscriptionEventsResponse struct {
|
||||
Events []*billing_models.SubscriptionEvent `json:"events"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
type GetInvoicesResponse struct {
|
||||
Invoices []*billing_models.Invoice `json:"invoices"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
15
backend/internal/features/billing/errors.go
Normal file
15
backend/internal/features/billing/errors.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package billing
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrInvalidStorage = errors.New("storage must be between 20 and 10000 GB")
|
||||
ErrAlreadySubscribed = errors.New("database already has an active subscription")
|
||||
ErrExceedsUsage = errors.New("cannot downgrade below current storage usage")
|
||||
ErrNoChange = errors.New("requested storage is the same as current")
|
||||
ErrDuplicate = errors.New("duplicate event already processed")
|
||||
ErrProviderUnavailable = errors.New("payment provider unavailable")
|
||||
ErrNoActiveSubscription = errors.New("no active subscription for this database")
|
||||
ErrAccessDenied = errors.New("user does not have access to this database")
|
||||
ErrSubscriptionNotFound = errors.New("subscription not found")
|
||||
)
|
||||
24
backend/internal/features/billing/models/invoice.go
Normal file
24
backend/internal/features/billing/models/invoice.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package billing_models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Invoice struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
SubscriptionID uuid.UUID `json:"subscriptionId" gorm:"column:subscription_id;type:uuid;not null"`
|
||||
ProviderInvoiceID string `json:"providerInvoiceId" gorm:"column:provider_invoice_id;type:text;not null"`
|
||||
AmountCents int64 `json:"amountCents" gorm:"column:amount_cents;type:bigint;not null"`
|
||||
StorageGB int `json:"storageGb" gorm:"column:storage_gb;type:int;not null"`
|
||||
PeriodStart time.Time `json:"periodStart" gorm:"column:period_start;type:timestamptz;not null"`
|
||||
PeriodEnd time.Time `json:"periodEnd" gorm:"column:period_end;type:timestamptz;not null"`
|
||||
Status InvoiceStatus `json:"status" gorm:"column:status;type:text;not null"`
|
||||
PaidAt *time.Time `json:"paidAt,omitzero" gorm:"column:paid_at;type:timestamptz"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"`
|
||||
}
|
||||
|
||||
func (Invoice) TableName() string {
|
||||
return "invoices"
|
||||
}
|
||||
11
backend/internal/features/billing/models/invoice_status.go
Normal file
11
backend/internal/features/billing/models/invoice_status.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package billing_models
|
||||
|
||||
type InvoiceStatus string
|
||||
|
||||
const (
|
||||
InvoiceStatusPending InvoiceStatus = "pending"
|
||||
InvoiceStatusPaid InvoiceStatus = "paid"
|
||||
InvoiceStatusFailed InvoiceStatus = "failed"
|
||||
InvoiceStatusRefunded InvoiceStatus = "refunded"
|
||||
InvoiceStatusDisputed InvoiceStatus = "disputed"
|
||||
)
|
||||
72
backend/internal/features/billing/models/subscription.go
Normal file
72
backend/internal/features/billing/models/subscription.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package billing_models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
)
|
||||
|
||||
type Subscription struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
|
||||
Status SubscriptionStatus `json:"status" gorm:"column:status;type:text;not null"`
|
||||
|
||||
StorageGB int `json:"storageGb" gorm:"column:storage_gb;type:int;not null"`
|
||||
PendingStorageGB *int `json:"pendingStorageGb,omitempty" gorm:"column:pending_storage_gb;type:int"`
|
||||
|
||||
CurrentPeriodStart time.Time `json:"currentPeriodStart" gorm:"column:current_period_start;type:timestamptz;not null"`
|
||||
CurrentPeriodEnd time.Time `json:"currentPeriodEnd" gorm:"column:current_period_end;type:timestamptz;not null"`
|
||||
CanceledAt *time.Time `json:"canceledAt,omitzero" gorm:"column:canceled_at;type:timestamptz"`
|
||||
|
||||
DataRetentionGracePeriodUntil *time.Time `json:"dataRetentionGracePeriodUntil,omitzero" gorm:"column:data_retention_grace_period_until;type:timestamptz"`
|
||||
|
||||
ProviderName *string `json:"providerName,omitempty" gorm:"column:provider_name;type:text"`
|
||||
ProviderSubID *string `json:"providerSubId,omitempty" gorm:"column:provider_sub_id;type:text"`
|
||||
ProviderCustomerID *string `json:"providerCustomerId,omitempty" gorm:"column:provider_customer_id;type:text"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"`
|
||||
UpdatedAt time.Time `json:"updatedAt" gorm:"column:updated_at;type:timestamptz;not null"`
|
||||
}
|
||||
|
||||
func (Subscription) TableName() string {
|
||||
return "subscriptions"
|
||||
}
|
||||
|
||||
func (s *Subscription) PriceCents() int64 {
|
||||
return int64(s.StorageGB) * config.GetEnv().PricePerGBCents
|
||||
}
|
||||
|
||||
// CanCreateNewBackups - whether it is allowed to create new backups
|
||||
// by scheduler or for user manually. Clarification: in grace period
|
||||
// user can download, delete and restore backups, but cannot create new ones
|
||||
func (s *Subscription) CanCreateNewBackups() bool {
|
||||
switch s.Status {
|
||||
case StatusActive, StatusPastDue:
|
||||
return true
|
||||
case StatusTrial, StatusCanceled:
|
||||
return time.Now().Before(s.CurrentPeriodEnd)
|
||||
case StatusExpired:
|
||||
return false
|
||||
default:
|
||||
panic("unknown subscription status")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Subscription) GetBackupsStorageGB() int {
|
||||
switch s.Status {
|
||||
case StatusActive, StatusPastDue, StatusCanceled:
|
||||
return s.StorageGB
|
||||
case StatusTrial:
|
||||
if time.Now().Before(s.CurrentPeriodEnd) {
|
||||
return s.StorageGB
|
||||
}
|
||||
|
||||
return 0
|
||||
case StatusExpired:
|
||||
return 0
|
||||
default:
|
||||
panic("unknown subscription status")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package billing_models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type SubscriptionEvent struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
SubscriptionID uuid.UUID `json:"subscriptionId" gorm:"column:subscription_id;type:uuid;not null"`
|
||||
ProviderEventID *string `json:"providerEventId,omitempty" gorm:"column:provider_event_id;type:text"`
|
||||
Type SubscriptionEventType `json:"type" gorm:"column:type;type:text;not null"`
|
||||
|
||||
OldStorageGB *int `json:"oldStorageGb,omitempty" gorm:"column:old_storage_gb;type:int"`
|
||||
NewStorageGB *int `json:"newStorageGb,omitempty" gorm:"column:new_storage_gb;type:int"`
|
||||
OldStatus *SubscriptionStatus `json:"oldStatus,omitempty" gorm:"column:old_status;type:text"`
|
||||
NewStatus *SubscriptionStatus `json:"newStatus,omitempty" gorm:"column:new_status;type:text"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;type:timestamptz;not null"`
|
||||
}
|
||||
|
||||
func (SubscriptionEvent) TableName() string {
|
||||
return "subscription_events"
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package billing_models
|
||||
|
||||
type SubscriptionEventType string
|
||||
|
||||
const (
|
||||
EventCreated SubscriptionEventType = "subscription.created"
|
||||
EventUpgraded SubscriptionEventType = "subscription.upgraded"
|
||||
EventDowngraded SubscriptionEventType = "subscription.downgraded"
|
||||
EventNewBillingCycleStarted SubscriptionEventType = "subscription.new_billing_cycle_started"
|
||||
EventCanceled SubscriptionEventType = "subscription.canceled"
|
||||
EventReactivated SubscriptionEventType = "subscription.reactivated"
|
||||
EventExpired SubscriptionEventType = "subscription.expired"
|
||||
EventPastDue SubscriptionEventType = "subscription.past_due"
|
||||
EventRecoveredFromPastDue SubscriptionEventType = "subscription.recovered_from_past_due"
|
||||
EventRefund SubscriptionEventType = "payment.refund"
|
||||
EventDispute SubscriptionEventType = "payment.dispute"
|
||||
)
|
||||
@@ -0,0 +1,11 @@
|
||||
package billing_models
|
||||
|
||||
type SubscriptionStatus string
|
||||
|
||||
const (
|
||||
StatusTrial SubscriptionStatus = "trial" // trial period (~24h after DB creation)
|
||||
StatusActive SubscriptionStatus = "active" // paid, everything works
|
||||
StatusPastDue SubscriptionStatus = "past_due" // payment failed, trying to charge again, but everything still works
|
||||
StatusCanceled SubscriptionStatus = "canceled" // subscription canceled by user or after past_due (grace period is active)
|
||||
StatusExpired SubscriptionStatus = "expired" // grace period ended, data marked for deletion, can come from canceled and trial
|
||||
)
|
||||
22
backend/internal/features/billing/models/webhook_event.go
Normal file
22
backend/internal/features/billing/models/webhook_event.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package billing_models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type WebhookEvent struct {
|
||||
RequestID uuid.UUID
|
||||
ProviderEventID string
|
||||
DatabaseID *uuid.UUID
|
||||
Type WebhookEventType
|
||||
ProviderSubscriptionID string
|
||||
ProviderCustomerID string
|
||||
ProviderInvoiceID string
|
||||
QuantityGB int
|
||||
Status SubscriptionStatus
|
||||
PeriodStart *time.Time
|
||||
PeriodEnd *time.Time
|
||||
AmountCents int64
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package billing_models
|
||||
|
||||
type WebhookEventType string
|
||||
|
||||
const (
|
||||
WHEventSubscriptionCreated WebhookEventType = "subscription.created"
|
||||
WHEventSubscriptionUpdated WebhookEventType = "subscription.updated"
|
||||
WHEventSubscriptionCanceled WebhookEventType = "subscription.canceled"
|
||||
WHEventSubscriptionPastDue WebhookEventType = "subscription.past_due"
|
||||
WHEventSubscriptionReactivated WebhookEventType = "subscription.reactivated"
|
||||
WHEventPaymentSucceeded WebhookEventType = "payment.succeeded"
|
||||
WHEventSubscriptionDisputeCreated WebhookEventType = "dispute.created"
|
||||
)
|
||||
5
backend/internal/features/billing/paddle/README.md
Normal file
5
backend/internal/features/billing/paddle/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
**Paddle hints:**
|
||||
|
||||
- **max_quantity on price:** Paddle limits `quantity` on a price to 100 by default. You need to explicitly set the range (`quantity: {minimum: 20, maximum: 10000}`) when creating a price via API or dashboard. Otherwise requests with quantity > 100 will return an error.
|
||||
- **Full items list on update:** Unlike Stripe, Paddle requires sending **all** subscription items in `PATCH /subscriptions/{id}`, not just the changed ones. `proration_billing_mode` is also required. Without this you can accidentally remove a line item or get a 400.
|
||||
- **Webhook events mapping:** Paddle uses `transaction.completed` instead of `payment.succeeded`, `transaction.payment_failed` instead of `payment.failed`, `adjustment.created` instead of `dispute.created`.
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user