diff --git a/AGENTS.md b/AGENTS.md index f931b0d..7979259 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -73,6 +73,10 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ - Patch the answer accordingly - Verify edge cases are handled +7. **Fix the reason, not the symptom:** + - If you find a bug or issue, ask "Why did this happen?" and fix the root cause + - Avoid quick fixes that don't address underlying problems + ### Application guidelines: **Scale your response to the task:** @@ -88,6 +92,36 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ ## Backend guidelines +### Naming + +Variables and functions naming are the most important part of code readability. Always choose descriptive and meaningful names that clearly indicate the purpose and intent of the code. + +Avoid abbreviations, unless they are widely accepted and unambiguous (e.g., `ID`, `URL`, `HTTP`). Use consistent naming conventions across the codebase. + +Do not use one-two letters. For example: + +Bad: + +``` + u := users.getUser() + + pr, pw := io.Pipe() + + r := bufio.NewReader(pr) +``` + +Good: + +``` + user := users.GetUser() + + pipeReader, pipeWriter := io.Pipe() + + bufferedReader := bufio.NewReader(pipeReader) +``` + +Exclusion: widely used variables like "db", "ctx", "req", "res", etc. + ### Code style **Always place private methods to the bottom of file** diff --git a/agent/.gitignore b/agent/.gitignore index da0f211..554a966 100644 --- a/agent/.gitignore +++ b/agent/.gitignore @@ -21,4 +21,6 @@ cmd.exe temp/ valkey-data/ victoria-logs-data/ -databasus.json \ No newline at end of file +databasus.json +.test-tmp/ +databasus.log \ No newline at end of file diff --git a/agent/cmd/main.go b/agent/cmd/main.go index 38b90c5..8665002 100644 --- a/agent/cmd/main.go +++ b/agent/cmd/main.go @@ -1,14 +1,17 @@ package main import ( + "errors" "flag" "fmt" "log/slog" "os" "path/filepath" "strings" + "syscall" "databasus-agent/internal/config" + "databasus-agent/internal/features/api" "databasus-agent/internal/features/start" "databasus-agent/internal/features/upgrade" "databasus-agent/internal/logger" @@ -25,6 +28,8 @@ func main() { switch os.Args[1] { case "start": runStart(os.Args[2:]) + case "_run": + runDaemon(os.Args[2:]) case "stop": runStop() case "status": @@ -43,7 +48,6 @@ func main() { func runStart(args []string) { fs := flag.NewFlagSet("start", flag.ExitOnError) - isDebug := fs.Bool("debug", false, "Enable debug logging") isSkipUpdate := fs.Bool("skip-update", false, "Skip auto-update check") cfg := &config.Config{} @@ -53,26 +57,59 @@ func runStart(args []string) { fmt.Fprintf(os.Stderr, "Failed to save config: %v\n", err) } - logger.Init(*isDebug) log := logger.GetLogger() isDev := checkIsDevelopment() runUpdateCheck(cfg.DatabasusHost, *isSkipUpdate, isDev, log) - if err := start.Run(cfg, log); err != nil { + if err := start.Start(cfg, Version, isDev, log); err != nil { + if errors.Is(err, upgrade.ErrUpgradeRestart) { + reexecAfterUpgrade(log) + } + fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } } +func runDaemon(args []string) { + fs := flag.NewFlagSet("_run", flag.ExitOnError) + + if err := fs.Parse(args); err != nil { + os.Exit(1) + } + + log := logger.GetLogger() + + cfg := &config.Config{} + cfg.LoadFromJSON() + + if err := start.RunDaemon(cfg, Version, checkIsDevelopment(), log); err != nil { + if errors.Is(err, upgrade.ErrUpgradeRestart) { + reexecAfterUpgrade(log) + } + + log.Error("Agent exited with error", "error", err) + os.Exit(1) + } +} + func runStop() { - logger.Init(false) - logger.GetLogger().Info("stop: stub — not yet implemented") + log := logger.GetLogger() + + if err := start.Stop(log); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } } func runStatus() { - logger.Init(false) - logger.GetLogger().Info("status: stub — not yet implemented") + log := logger.GetLogger() + + if err := start.Status(log); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } } func runRestore(args []string) { @@ -82,7 +119,6 @@ func runRestore(args []string) { backupID := fs.String("backup-id", "", "Full backup UUID (optional)") targetTime := fs.String("target-time", "", "PITR target time in RFC3339 (optional)") isYes := fs.Bool("yes", false, "Skip confirmation prompt") - isDebug := fs.Bool("debug", false, "Enable debug logging") isSkipUpdate := fs.Bool("skip-update", false, "Skip auto-update check") cfg := &config.Config{} @@ -92,7 +128,6 @@ func runRestore(args []string) { fmt.Fprintf(os.Stderr, "Failed to save config: %v\n", err) } - logger.Init(*isDebug) log := logger.GetLogger() isDev := checkIsDevelopment() @@ -126,10 +161,17 @@ func runUpdateCheck(host string, isSkipUpdate, isDev bool, log *slog.Logger) { return } - if err := upgrade.CheckAndUpdate(host, Version, isDev, log); err != nil { + apiClient := api.NewClient(host, "", log) + + isUpgraded, err := upgrade.CheckAndUpdate(apiClient, Version, isDev, log) + if err != nil { log.Error("Auto-update failed", "error", err) os.Exit(1) } + + if isUpgraded { + reexecAfterUpgrade(log) + } } func checkIsDevelopment() bool { @@ -168,3 +210,18 @@ func parseEnvMode(data []byte) bool { return false } + +func reexecAfterUpgrade(log *slog.Logger) { + selfPath, err := os.Executable() + if err != nil { + log.Error("Failed to resolve executable for re-exec", "error", err) + os.Exit(1) + } + + log.Info("Re-executing after upgrade...") + + if err := syscall.Exec(selfPath, os.Args, os.Environ()); err != nil { + log.Error("Failed to re-exec after upgrade", "error", err) + os.Exit(1) + } +} diff --git a/agent/e2e/mock-server/main.go b/agent/e2e/mock-server/main.go index 60d0432..d55d3f6 100644 --- a/agent/e2e/mock-server/main.go +++ b/agent/e2e/mock-server/main.go @@ -24,6 +24,7 @@ func main() { http.HandleFunc("/api/v1/system/version", s.handleVersion) http.HandleFunc("/api/v1/system/agent", s.handleAgentDownload) http.HandleFunc("/mock/set-version", s.handleSetVersion) + http.HandleFunc("/mock/set-binary-path", s.handleSetBinaryPath) http.HandleFunc("/health", s.handleHealth) addr := ":" + port @@ -78,6 +79,29 @@ func (s *server) handleSetVersion(w http.ResponseWriter, r *http.Request) { _, _ = fmt.Fprintf(w, "version set to %s", body.Version) } +func (s *server) handleSetBinaryPath(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "POST only", http.StatusMethodNotAllowed) + return + } + + var body struct { + BinaryPath string `json:"binaryPath"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + s.mu.Lock() + s.binaryPath = body.BinaryPath + s.mu.Unlock() + + log.Printf("POST /mock/set-binary-path -> %s", body.BinaryPath) + + _, _ = fmt.Fprintf(w, "binary path set to %s", body.BinaryPath) +} + func (s *server) handleHealth(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) diff --git a/agent/e2e/scripts/run-all.sh b/agent/e2e/scripts/run-all.sh index 0cc9256..a026e4d 100644 --- a/agent/e2e/scripts/run-all.sh +++ b/agent/e2e/scripts/run-all.sh @@ -27,11 +27,12 @@ run_test() { if [ "$MODE" = "host" ]; then run_test "Test 1: Upgrade success (v1 -> v2)" "$SCRIPT_DIR/test-upgrade-success.sh" run_test "Test 2: Upgrade skip (version matches)" "$SCRIPT_DIR/test-upgrade-skip.sh" - run_test "Test 3: pg_basebackup in PATH" "$SCRIPT_DIR/test-pg-host-path.sh" - run_test "Test 4: pg_basebackup via bindir" "$SCRIPT_DIR/test-pg-host-bindir.sh" + run_test "Test 3: Background upgrade (v1 -> v2 while running)" "$SCRIPT_DIR/test-upgrade-background.sh" + run_test "Test 4: pg_basebackup in PATH" "$SCRIPT_DIR/test-pg-host-path.sh" + run_test "Test 5: pg_basebackup via bindir" "$SCRIPT_DIR/test-pg-host-bindir.sh" elif [ "$MODE" = "docker" ]; then - run_test "Test 5: pg_basebackup via docker exec" "$SCRIPT_DIR/test-pg-docker-exec.sh" + run_test "Test 6: pg_basebackup via docker exec" "$SCRIPT_DIR/test-pg-docker-exec.sh" else echo "Unknown mode: $MODE (expected 'host' or 'docker')" diff --git a/agent/e2e/scripts/test-pg-docker-exec.sh b/agent/e2e/scripts/test-pg-docker-exec.sh index 56a13d4..bc166bb 100644 --- a/agent/e2e/scripts/test-pg-docker-exec.sh +++ b/agent/e2e/scripts/test-pg-docker-exec.sh @@ -5,6 +5,16 @@ ARTIFACTS="/opt/agent/artifacts" AGENT="/tmp/test-agent" PG_CONTAINER="e2e-agent-postgres" +# Cleanup from previous runs +pkill -f "test-agent" 2>/dev/null || true +for i in $(seq 1 20); do + pgrep -f "test-agent" > /dev/null 2>&1 || break + sleep 0.5 +done +pkill -9 -f "test-agent" 2>/dev/null || true +sleep 0.5 +rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true + # Copy agent binary cp "$ARTIFACTS/agent-v1" "$AGENT" chmod +x "$AGENT" @@ -26,7 +36,7 @@ OUTPUT=$("$AGENT" start \ --pg-port 5432 \ --pg-user testuser \ --pg-password testpassword \ - --wal-dir /tmp/wal \ + --pg-wal-dir /tmp/wal \ --pg-type docker \ --pg-docker-container-name "$PG_CONTAINER" 2>&1) diff --git a/agent/e2e/scripts/test-pg-host-bindir.sh b/agent/e2e/scripts/test-pg-host-bindir.sh index 38f911e..3be4033 100644 --- a/agent/e2e/scripts/test-pg-host-bindir.sh +++ b/agent/e2e/scripts/test-pg-host-bindir.sh @@ -5,6 +5,16 @@ ARTIFACTS="/opt/agent/artifacts" AGENT="/tmp/test-agent" CUSTOM_BIN_DIR="/opt/pg/bin" +# Cleanup from previous runs +pkill -f "test-agent" 2>/dev/null || true +for i in $(seq 1 20); do + pgrep -f "test-agent" > /dev/null 2>&1 || break + sleep 0.5 +done +pkill -9 -f "test-agent" 2>/dev/null || true +sleep 0.5 +rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true + # Copy agent binary cp "$ARTIFACTS/agent-v1" "$AGENT" chmod +x "$AGENT" @@ -32,7 +42,7 @@ OUTPUT=$("$AGENT" start \ --pg-port 5432 \ --pg-user testuser \ --pg-password testpassword \ - --wal-dir /tmp/wal \ + --pg-wal-dir /tmp/wal \ --pg-type host \ --pg-host-bin-dir "$CUSTOM_BIN_DIR" 2>&1) diff --git a/agent/e2e/scripts/test-pg-host-path.sh b/agent/e2e/scripts/test-pg-host-path.sh index b3c73bb..bf6f5d8 100644 --- a/agent/e2e/scripts/test-pg-host-path.sh +++ b/agent/e2e/scripts/test-pg-host-path.sh @@ -4,6 +4,16 @@ set -euo pipefail ARTIFACTS="/opt/agent/artifacts" AGENT="/tmp/test-agent" +# Cleanup from previous runs +pkill -f "test-agent" 2>/dev/null || true +for i in $(seq 1 20); do + pgrep -f "test-agent" > /dev/null 2>&1 || break + sleep 0.5 +done +pkill -9 -f "test-agent" 2>/dev/null || true +sleep 0.5 +rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true + # Copy agent binary cp "$ARTIFACTS/agent-v1" "$AGENT" chmod +x "$AGENT" @@ -25,7 +35,7 @@ OUTPUT=$("$AGENT" start \ --pg-port 5432 \ --pg-user testuser \ --pg-password testpassword \ - --wal-dir /tmp/wal \ + --pg-wal-dir /tmp/wal \ --pg-type host 2>&1) EXIT_CODE=$? diff --git a/agent/e2e/scripts/test-upgrade-background.sh b/agent/e2e/scripts/test-upgrade-background.sh new file mode 100644 index 0000000..b85e6bf --- /dev/null +++ b/agent/e2e/scripts/test-upgrade-background.sh @@ -0,0 +1,90 @@ +#!/bin/bash +set -euo pipefail + +ARTIFACTS="/opt/agent/artifacts" +AGENT="/tmp/test-agent" + +# Cleanup from previous runs +pkill -f "test-agent" 2>/dev/null || true +for i in $(seq 1 20); do + pgrep -f "test-agent" > /dev/null 2>&1 || break + sleep 0.5 +done +pkill -9 -f "test-agent" 2>/dev/null || true +sleep 0.5 +rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true + +# Set mock server to v1.0.0 (same as agent — no sync upgrade on start) +curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \ + -H "Content-Type: application/json" \ + -d '{"version":"v1.0.0"}' + +curl -sf -X POST http://e2e-mock-server:4050/mock/set-binary-path \ + -H "Content-Type: application/json" \ + -d '{"binaryPath":"/artifacts/agent-v1"}' + +# Copy v1 binary to writable location +cp "$ARTIFACTS/agent-v1" "$AGENT" +chmod +x "$AGENT" + +# Verify initial version +VERSION=$("$AGENT" version) +if [ "$VERSION" != "v1.0.0" ]; then + echo "FAIL: Expected initial version v1.0.0, got $VERSION" + exit 1 +fi +echo "Initial version: $VERSION" + +# Start agent as daemon (versions match → no sync upgrade) +mkdir -p /tmp/wal +"$AGENT" start \ + --databasus-host http://e2e-mock-server:4050 \ + --db-id test-db-id \ + --token test-token \ + --pg-host e2e-postgres \ + --pg-port 5432 \ + --pg-user testuser \ + --pg-password testpassword \ + --pg-wal-dir /tmp/wal \ + --pg-type host + +echo "Agent started as daemon, waiting for stabilization..." +sleep 2 + +# Change mock server to v2.0.0 and point to v2 binary +curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \ + -H "Content-Type: application/json" \ + -d '{"version":"v2.0.0"}' + +curl -sf -X POST http://e2e-mock-server:4050/mock/set-binary-path \ + -H "Content-Type: application/json" \ + -d '{"binaryPath":"/artifacts/agent-v2"}' + +echo "Mock server updated to v2.0.0, waiting for background upgrade..." + +# Poll for upgrade (timeout 60s, poll every 3s) +DEADLINE=$((SECONDS + 60)) +while [ $SECONDS -lt $DEADLINE ]; do + VERSION=$("$AGENT" version) + if [ "$VERSION" = "v2.0.0" ]; then + echo "Binary upgraded to $VERSION" + break + fi + sleep 3 +done + +VERSION=$("$AGENT" version) +if [ "$VERSION" != "v2.0.0" ]; then + echo "FAIL: Expected v2.0.0 after background upgrade, got $VERSION" + cat databasus.log 2>/dev/null || true + exit 1 +fi + +# Verify agent is still running after restart +sleep 2 +"$AGENT" status || true + +# Cleanup +"$AGENT" stop || true + +echo "Background upgrade test passed" diff --git a/agent/e2e/scripts/test-upgrade-skip.sh b/agent/e2e/scripts/test-upgrade-skip.sh index 3bf7d0b..06feab2 100644 --- a/agent/e2e/scripts/test-upgrade-skip.sh +++ b/agent/e2e/scripts/test-upgrade-skip.sh @@ -4,6 +4,16 @@ set -euo pipefail ARTIFACTS="/opt/agent/artifacts" AGENT="/tmp/test-agent" +# Cleanup from previous runs +pkill -f "test-agent" 2>/dev/null || true +for i in $(seq 1 20); do + pgrep -f "test-agent" > /dev/null 2>&1 || break + sleep 0.5 +done +pkill -9 -f "test-agent" 2>/dev/null || true +sleep 0.5 +rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true + # Set mock server to return v1.0.0 (same as agent) curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \ -H "Content-Type: application/json" \ @@ -30,7 +40,7 @@ OUTPUT=$("$AGENT" start \ --pg-port 5432 \ --pg-user testuser \ --pg-password testpassword \ - --wal-dir /tmp/wal \ + --pg-wal-dir /tmp/wal \ --pg-type host 2>&1) || true echo "$OUTPUT" diff --git a/agent/e2e/scripts/test-upgrade-success.sh b/agent/e2e/scripts/test-upgrade-success.sh index 338ebd6..722a4e1 100644 --- a/agent/e2e/scripts/test-upgrade-success.sh +++ b/agent/e2e/scripts/test-upgrade-success.sh @@ -4,11 +4,25 @@ set -euo pipefail ARTIFACTS="/opt/agent/artifacts" AGENT="/tmp/test-agent" -# Ensure mock server returns v2.0.0 +# Cleanup from previous runs +pkill -f "test-agent" 2>/dev/null || true +for i in $(seq 1 20); do + pgrep -f "test-agent" > /dev/null 2>&1 || break + sleep 0.5 +done +pkill -9 -f "test-agent" 2>/dev/null || true +sleep 0.5 +rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true + +# Ensure mock server returns v2.0.0 and serves v2 binary curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \ -H "Content-Type: application/json" \ -d '{"version":"v2.0.0"}' +curl -sf -X POST http://e2e-mock-server:4050/mock/set-binary-path \ + -H "Content-Type: application/json" \ + -d '{"binaryPath":"/artifacts/agent-v2"}' + # Copy v1 binary to writable location cp "$ARTIFACTS/agent-v1" "$AGENT" chmod +x "$AGENT" @@ -37,7 +51,7 @@ OUTPUT=$("$AGENT" start \ --pg-port 5432 \ --pg-user testuser \ --pg-password testpassword \ - --wal-dir /tmp/wal \ + --pg-wal-dir /tmp/wal \ --pg-type host 2>&1) || true echo "$OUTPUT" diff --git a/agent/go.mod b/agent/go.mod index ccec2af..ec3248c 100644 --- a/agent/go.mod +++ b/agent/go.mod @@ -3,7 +3,9 @@ module databasus-agent go 1.26.1 require ( + github.com/go-resty/resty/v2 v2.17.2 github.com/jackc/pgx/v5 v5.8.0 + github.com/klauspost/compress v1.18.4 github.com/stretchr/testify v1.11.1 ) @@ -14,6 +16,7 @@ require ( github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.14.1 // indirect + golang.org/x/net v0.43.0 // indirect golang.org/x/text v0.29.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/agent/go.sum b/agent/go.sum index 6fe46bd..76740e0 100644 --- a/agent/go.sum +++ b/agent/go.sum @@ -2,6 +2,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk= +github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -10,6 +12,8 @@ github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c= +github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -23,10 +27,14 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/agent/internal/config/config.go b/agent/internal/config/config.go index 17709de..a70e5a2 100644 --- a/agent/internal/config/config.go +++ b/agent/internal/config/config.go @@ -24,7 +24,7 @@ type Config struct { PgType string `json:"pgType"` PgHostBinDir string `json:"pgHostBinDir"` PgDockerContainerName string `json:"pgDockerContainerName"` - WalDir string `json:"walDir"` + PgWalDir string `json:"pgWalDir"` IsDeleteWalAfterUpload *bool `json:"deleteWalAfterUpload"` flags parsedFlags @@ -51,7 +51,7 @@ func (c *Config) LoadFromJSONAndArgs(fs *flag.FlagSet, args []string) { c.flags.pgType = fs.String("pg-type", "", "PostgreSQL type: host or docker") c.flags.pgHostBinDir = fs.String("pg-host-bin-dir", "", "Path to PG bin directory (host mode)") c.flags.pgDockerContainerName = fs.String("pg-docker-container-name", "", "Docker container name (docker mode)") - c.flags.walDir = fs.String("wal-dir", "", "Path to WAL queue directory") + c.flags.pgWalDir = fs.String("pg-wal-dir", "", "Path to WAL queue directory") if err := fs.Parse(args); err != nil { os.Exit(1) @@ -73,6 +73,11 @@ func (c *Config) SaveToJSON() error { return os.WriteFile(configFileName, data, 0o644) } +func (c *Config) LoadFromJSON() { + c.loadFromJSON() + c.applyDefaults() +} + func (c *Config) loadFromJSON() { data, err := os.ReadFile(configFileName) if err != nil { @@ -122,7 +127,7 @@ func (c *Config) initSources() { "pg-type": "not configured", "pg-host-bin-dir": "not configured", "pg-docker-container-name": "not configured", - "wal-dir": "not configured", + "pg-wal-dir": "not configured", "delete-wal-after-upload": "not configured", } @@ -164,8 +169,8 @@ func (c *Config) initSources() { c.flags.sources["pg-docker-container-name"] = configFileName } - if c.WalDir != "" { - c.flags.sources["wal-dir"] = configFileName + if c.PgWalDir != "" { + c.flags.sources["pg-wal-dir"] = configFileName } // IsDeleteWalAfterUpload always has a value after applyDefaults @@ -223,9 +228,9 @@ func (c *Config) applyFlags() { c.flags.sources["pg-docker-container-name"] = "command line args" } - if c.flags.walDir != nil && *c.flags.walDir != "" { - c.WalDir = *c.flags.walDir - c.flags.sources["wal-dir"] = "command line args" + if c.flags.pgWalDir != nil && *c.flags.pgWalDir != "" { + c.PgWalDir = *c.flags.pgWalDir + c.flags.sources["pg-wal-dir"] = "command line args" } } @@ -246,7 +251,7 @@ func (c *Config) logConfigSources() { "source", c.flags.sources["pg-docker-container-name"], ) - log.Info("wal-dir", "value", c.WalDir, "source", c.flags.sources["wal-dir"]) + log.Info("pg-wal-dir", "value", c.PgWalDir, "source", c.flags.sources["pg-wal-dir"]) log.Info( "delete-wal-after-upload", "value", diff --git a/agent/internal/config/config_test.go b/agent/internal/config/config_test.go index 8a5b9bd..79726b5 100644 --- a/agent/internal/config/config_test.go +++ b/agent/internal/config/config_test.go @@ -142,7 +142,7 @@ func Test_LoadFromJSONAndArgs_PgFieldsLoadedFromJSON(t *testing.T) { PgType: "docker", PgHostBinDir: "/usr/bin", PgDockerContainerName: "pg-container", - WalDir: "/opt/wal", + PgWalDir: "/opt/wal", IsDeleteWalAfterUpload: &deleteWal, }) @@ -157,7 +157,7 @@ func Test_LoadFromJSONAndArgs_PgFieldsLoadedFromJSON(t *testing.T) { assert.Equal(t, "docker", cfg.PgType) assert.Equal(t, "/usr/bin", cfg.PgHostBinDir) assert.Equal(t, "pg-container", cfg.PgDockerContainerName) - assert.Equal(t, "/opt/wal", cfg.WalDir) + assert.Equal(t, "/opt/wal", cfg.PgWalDir) assert.Equal(t, false, *cfg.IsDeleteWalAfterUpload) } @@ -174,7 +174,7 @@ func Test_LoadFromJSONAndArgs_PgFieldsLoadedFromArgs(t *testing.T) { "--pg-type", "docker", "--pg-host-bin-dir", "/custom/bin", "--pg-docker-container-name", "my-pg", - "--wal-dir", "/var/wal", + "--pg-wal-dir", "/var/wal", }) assert.Equal(t, "arg-pg-host", cfg.PgHost) @@ -184,17 +184,17 @@ func Test_LoadFromJSONAndArgs_PgFieldsLoadedFromArgs(t *testing.T) { assert.Equal(t, "docker", cfg.PgType) assert.Equal(t, "/custom/bin", cfg.PgHostBinDir) assert.Equal(t, "my-pg", cfg.PgDockerContainerName) - assert.Equal(t, "/var/wal", cfg.WalDir) + assert.Equal(t, "/var/wal", cfg.PgWalDir) } func Test_LoadFromJSONAndArgs_PgArgsOverrideJSON(t *testing.T) { dir := setupTempDir(t) writeConfigJSON(t, dir, Config{ - PgHost: "json-host", - PgPort: 5432, - PgUser: "json-user", - PgType: "host", - WalDir: "/json/wal", + PgHost: "json-host", + PgPort: 5432, + PgUser: "json-user", + PgType: "host", + PgWalDir: "/json/wal", }) cfg := &Config{} @@ -205,7 +205,7 @@ func Test_LoadFromJSONAndArgs_PgArgsOverrideJSON(t *testing.T) { "--pg-user", "arg-user", "--pg-type", "docker", "--pg-docker-container-name", "my-container", - "--wal-dir", "/arg/wal", + "--pg-wal-dir", "/arg/wal", }) assert.Equal(t, "arg-host", cfg.PgHost) @@ -213,7 +213,7 @@ func Test_LoadFromJSONAndArgs_PgArgsOverrideJSON(t *testing.T) { assert.Equal(t, "arg-user", cfg.PgUser) assert.Equal(t, "docker", cfg.PgType) assert.Equal(t, "my-container", cfg.PgDockerContainerName) - assert.Equal(t, "/arg/wal", cfg.WalDir) + assert.Equal(t, "/arg/wal", cfg.PgWalDir) } func Test_LoadFromJSONAndArgs_DefaultsApplied_WhenNoJSONAndNoArgs(t *testing.T) { @@ -244,7 +244,7 @@ func Test_SaveToJSON_PgFieldsSavedCorrectly(t *testing.T) { PgType: "docker", PgHostBinDir: "/usr/bin", PgDockerContainerName: "pg-container", - WalDir: "/opt/wal", + PgWalDir: "/opt/wal", IsDeleteWalAfterUpload: &deleteWal, } @@ -260,7 +260,7 @@ func Test_SaveToJSON_PgFieldsSavedCorrectly(t *testing.T) { assert.Equal(t, "docker", saved.PgType) assert.Equal(t, "/usr/bin", saved.PgHostBinDir) assert.Equal(t, "pg-container", saved.PgDockerContainerName) - assert.Equal(t, "/opt/wal", saved.WalDir) + assert.Equal(t, "/opt/wal", saved.PgWalDir) require.NotNil(t, saved.IsDeleteWalAfterUpload) assert.Equal(t, false, *saved.IsDeleteWalAfterUpload) } diff --git a/agent/internal/config/dto.go b/agent/internal/config/dto.go index a03545c..5298298 100644 --- a/agent/internal/config/dto.go +++ b/agent/internal/config/dto.go @@ -11,7 +11,7 @@ type parsedFlags struct { pgType *string pgHostBinDir *string pgDockerContainerName *string - walDir *string + pgWalDir *string sources map[string]string } diff --git a/agent/internal/features/api/api.go b/agent/internal/features/api/api.go new file mode 100644 index 0000000..54da8f3 --- /dev/null +++ b/agent/internal/features/api/api.go @@ -0,0 +1,288 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "time" + + "github.com/go-resty/resty/v2" +) + +const ( + chainValidPath = "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup" + nextBackupTimePath = "/api/v1/backups/postgres/wal/next-full-backup-time" + walUploadPath = "/api/v1/backups/postgres/wal/upload/wal" + fullStartPath = "/api/v1/backups/postgres/wal/upload/full-start" + fullCompletePath = "/api/v1/backups/postgres/wal/upload/full-complete" + reportErrorPath = "/api/v1/backups/postgres/wal/error" + versionPath = "/api/v1/system/version" + agentBinaryPath = "/api/v1/system/agent" + + apiCallTimeout = 30 * time.Second + maxRetryAttempts = 3 + retryBaseDelay = 1 * time.Second +) + +type Client struct { + json *resty.Client + stream *resty.Client + host string + log *slog.Logger +} + +func NewClient(host, token string, log *slog.Logger) *Client { + setAuth := func(_ *resty.Client, req *resty.Request) error { + if token != "" { + req.SetHeader("Authorization", token) + } + + return nil + } + + jsonClient := resty.New(). + SetTimeout(apiCallTimeout). + SetRetryCount(maxRetryAttempts - 1). + SetRetryWaitTime(retryBaseDelay). + SetRetryMaxWaitTime(4 * retryBaseDelay). + AddRetryCondition(func(resp *resty.Response, err error) bool { + return err != nil || resp.StatusCode() >= 500 + }). + OnBeforeRequest(setAuth) + + streamClient := resty.New(). + OnBeforeRequest(setAuth) + + return &Client{ + json: jsonClient, + stream: streamClient, + host: host, + log: log, + } +} + +func (c *Client) CheckWalChainValidity(ctx context.Context) (*WalChainValidityResponse, error) { + var resp WalChainValidityResponse + + httpResp, err := c.json.R(). + SetContext(ctx). + SetResult(&resp). + Get(c.buildURL(chainValidPath)) + if err != nil { + return nil, err + } + + if err := c.checkResponse(httpResp, "check WAL chain validity"); err != nil { + return nil, err + } + + return &resp, nil +} + +func (c *Client) GetNextFullBackupTime(ctx context.Context) (*NextFullBackupTimeResponse, error) { + var resp NextFullBackupTimeResponse + + httpResp, err := c.json.R(). + SetContext(ctx). + SetResult(&resp). + Get(c.buildURL(nextBackupTimePath)) + if err != nil { + return nil, err + } + + if err := c.checkResponse(httpResp, "get next full backup time"); err != nil { + return nil, err + } + + return &resp, nil +} + +func (c *Client) ReportBackupError(ctx context.Context, errMsg string) error { + httpResp, err := c.json.R(). + SetContext(ctx). + SetBody(reportErrorRequest{Error: errMsg}). + Post(c.buildURL(reportErrorPath)) + if err != nil { + return err + } + + return c.checkResponse(httpResp, "report backup error") +} + +func (c *Client) UploadBasebackup( + ctx context.Context, + body io.Reader, +) (*UploadBasebackupResponse, error) { + resp, err := c.stream.R(). + SetContext(ctx). + SetBody(body). + SetHeader("Content-Type", "application/octet-stream"). + SetDoNotParseResponse(true). + Post(c.buildURL(fullStartPath)) + if err != nil { + return nil, fmt.Errorf("upload request: %w", err) + } + defer func() { _ = resp.RawBody().Close() }() + + if resp.StatusCode() != http.StatusOK { + respBody, _ := io.ReadAll(resp.RawBody()) + + return nil, fmt.Errorf("upload failed with status %d: %s", resp.StatusCode(), string(respBody)) + } + + var result UploadBasebackupResponse + if err := json.NewDecoder(resp.RawBody()).Decode(&result); err != nil { + return nil, fmt.Errorf("decode upload response: %w", err) + } + + return &result, nil +} + +func (c *Client) FinalizeBasebackup( + ctx context.Context, + backupID string, + startSegment string, + stopSegment string, +) error { + resp, err := c.json.R(). + SetContext(ctx). + SetBody(finalizeBasebackupRequest{ + BackupID: backupID, + StartSegment: startSegment, + StopSegment: stopSegment, + }). + Post(c.buildURL(fullCompletePath)) + if err != nil { + return fmt.Errorf("finalize request: %w", err) + } + + if resp.StatusCode() != http.StatusOK { + return fmt.Errorf("finalize failed with status %d: %s", resp.StatusCode(), resp.String()) + } + + return nil +} + +func (c *Client) FinalizeBasebackupWithError( + ctx context.Context, + backupID string, + errMsg string, +) error { + resp, err := c.json.R(). + SetContext(ctx). + SetBody(finalizeBasebackupRequest{ + BackupID: backupID, + Error: &errMsg, + }). + Post(c.buildURL(fullCompletePath)) + if err != nil { + return fmt.Errorf("finalize-with-error request: %w", err) + } + + if resp.StatusCode() != http.StatusOK { + return fmt.Errorf("finalize-with-error failed with status %d: %s", resp.StatusCode(), resp.String()) + } + + return nil +} + +func (c *Client) UploadWalSegment( + ctx context.Context, + segmentName string, + body io.Reader, +) (*UploadWalSegmentResult, error) { + resp, err := c.stream.R(). + SetContext(ctx). + SetBody(body). + SetHeader("Content-Type", "application/octet-stream"). + SetHeader("X-Wal-Segment-Name", segmentName). + SetDoNotParseResponse(true). + Post(c.buildURL(walUploadPath)) + if err != nil { + return nil, fmt.Errorf("upload request: %w", err) + } + defer func() { _ = resp.RawBody().Close() }() + + switch resp.StatusCode() { + case http.StatusNoContent: + return &UploadWalSegmentResult{IsGapDetected: false}, nil + + case http.StatusConflict: + var errResp uploadErrorResponse + + if err := json.NewDecoder(resp.RawBody()).Decode(&errResp); err != nil { + return &UploadWalSegmentResult{IsGapDetected: true}, nil + } + + return &UploadWalSegmentResult{ + IsGapDetected: true, + ExpectedSegmentName: errResp.ExpectedSegmentName, + ReceivedSegmentName: errResp.ReceivedSegmentName, + }, nil + + default: + respBody, _ := io.ReadAll(resp.RawBody()) + + return nil, fmt.Errorf("upload failed with status %d: %s", resp.StatusCode(), string(respBody)) + } +} + +func (c *Client) FetchServerVersion(ctx context.Context) (string, error) { + var ver versionResponse + + httpResp, err := c.json.R(). + SetContext(ctx). + SetResult(&ver). + Get(c.buildURL(versionPath)) + if err != nil { + return "", err + } + + if err := c.checkResponse(httpResp, "fetch server version"); err != nil { + return "", err + } + + return ver.Version, nil +} + +func (c *Client) DownloadAgentBinary(ctx context.Context, arch, destPath string) error { + resp, err := c.stream.R(). + SetContext(ctx). + SetQueryParam("arch", arch). + SetDoNotParseResponse(true). + Get(c.buildURL(agentBinaryPath)) + if err != nil { + return err + } + defer func() { _ = resp.RawBody().Close() }() + + if resp.StatusCode() != http.StatusOK { + return fmt.Errorf("server returned %d for agent download", resp.StatusCode()) + } + + f, err := os.Create(destPath) + if err != nil { + return err + } + defer func() { _ = f.Close() }() + + _, err = io.Copy(f, resp.RawBody()) + + return err +} + +func (c *Client) buildURL(path string) string { + return c.host + path +} + +func (c *Client) checkResponse(resp *resty.Response, method string) error { + if resp.StatusCode() >= 400 { + return fmt.Errorf("%s: server returned status %d: %s", method, resp.StatusCode(), resp.String()) + } + + return nil +} diff --git a/agent/internal/features/api/dto.go b/agent/internal/features/api/dto.go new file mode 100644 index 0000000..3495b96 --- /dev/null +++ b/agent/internal/features/api/dto.go @@ -0,0 +1,44 @@ +package api + +import "time" + +type WalChainValidityResponse struct { + IsValid bool `json:"isValid"` + Error string `json:"error,omitempty"` + LastContiguousSegment string `json:"lastContiguousSegment,omitempty"` +} + +type NextFullBackupTimeResponse struct { + NextFullBackupTime *time.Time `json:"nextFullBackupTime"` +} + +type UploadWalSegmentResult struct { + IsGapDetected bool + ExpectedSegmentName string + ReceivedSegmentName string +} + +type reportErrorRequest struct { + Error string `json:"error"` +} + +type versionResponse struct { + Version string `json:"version"` +} + +type UploadBasebackupResponse struct { + BackupID string `json:"backupId"` +} + +type finalizeBasebackupRequest struct { + BackupID string `json:"backupId"` + StartSegment string `json:"startSegment"` + StopSegment string `json:"stopSegment"` + Error *string `json:"error,omitempty"` +} + +type uploadErrorResponse struct { + Error string `json:"error"` + ExpectedSegmentName string `json:"expectedSegmentName"` + ReceivedSegmentName string `json:"receivedSegmentName"` +} diff --git a/agent/internal/features/full_backup/backuper.go b/agent/internal/features/full_backup/backuper.go new file mode 100644 index 0000000..478809e --- /dev/null +++ b/agent/internal/features/full_backup/backuper.go @@ -0,0 +1,292 @@ +package full_backup + +import ( + "bytes" + "context" + "fmt" + "io" + "log/slog" + "os" + "os/exec" + "path/filepath" + "sync/atomic" + "time" + + "github.com/klauspost/compress/zstd" + + "databasus-agent/internal/config" + "databasus-agent/internal/features/api" +) + +const ( + checkInterval = 30 * time.Second + retryDelay = 1 * time.Minute + uploadTimeout = 30 * time.Minute +) + +var retryDelayOverride *time.Duration + +type CmdBuilder func(ctx context.Context) *exec.Cmd + +// FullBackuper runs pg_basebackup when the WAL chain is broken or a scheduled backup is due. +// +// Every 30 seconds it checks two conditions via the Databasus API: +// 1. WAL chain validity — if broken or no full backup exists, triggers an immediate basebackup. +// 2. Scheduled backup time — if the next full backup time has passed, triggers a basebackup. +// +// Only one basebackup runs at a time (guarded by atomic bool). +// On failure the error is reported to the server and the backup retries after 1 minute, indefinitely. +// WAL segment uploads (handled by wal.Streamer) continue independently and are not paused. +// +// pg_basebackup runs as "pg_basebackup -Ft -D - -X none --verbose". Stdout (tar) is zstd-compressed +// and uploaded to the server. Stderr is parsed for WAL start/stop segment names (LSN → segment arithmetic). +type FullBackuper struct { + cfg *config.Config + apiClient *api.Client + log *slog.Logger + isRunning atomic.Bool + cmdBuilder CmdBuilder +} + +func NewFullBackuper(cfg *config.Config, apiClient *api.Client, log *slog.Logger) *FullBackuper { + backuper := &FullBackuper{ + cfg: cfg, + apiClient: apiClient, + log: log, + } + + backuper.cmdBuilder = backuper.defaultCmdBuilder + + return backuper +} + +func (backuper *FullBackuper) Run(ctx context.Context) { + backuper.log.Info("Full backuper started") + + backuper.checkAndRunIfNeeded(ctx) + + ticker := time.NewTicker(checkInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + backuper.log.Info("Full backuper stopping") + return + case <-ticker.C: + backuper.checkAndRunIfNeeded(ctx) + } + } +} + +func (backuper *FullBackuper) checkAndRunIfNeeded(ctx context.Context) { + if backuper.isRunning.Load() { + backuper.log.Debug("Skipping check: basebackup already in progress") + return + } + + chainResp, err := backuper.apiClient.CheckWalChainValidity(ctx) + if err != nil { + backuper.log.Error("Failed to check WAL chain validity", "error", err) + return + } + + if !chainResp.IsValid { + backuper.log.Info("WAL chain is invalid, triggering basebackup", + "error", chainResp.Error, + "lastContiguousSegment", chainResp.LastContiguousSegment, + ) + + backuper.runBasebackupWithRetry(ctx) + + return + } + + nextTimeResp, err := backuper.apiClient.GetNextFullBackupTime(ctx) + if err != nil { + backuper.log.Error("Failed to check next full backup time", "error", err) + return + } + + if nextTimeResp.NextFullBackupTime == nil || !nextTimeResp.NextFullBackupTime.After(time.Now().UTC()) { + backuper.log.Info("Scheduled full backup is due, triggering basebackup") + backuper.runBasebackupWithRetry(ctx) + + return + } + + backuper.log.Debug("No basebackup needed", + "nextFullBackupTime", nextTimeResp.NextFullBackupTime, + ) +} + +func (backuper *FullBackuper) runBasebackupWithRetry(ctx context.Context) { + if !backuper.isRunning.CompareAndSwap(false, true) { + backuper.log.Debug("Skipping basebackup: already running") + return + } + defer backuper.isRunning.Store(false) + + for { + if ctx.Err() != nil { + return + } + + backuper.log.Info("Starting pg_basebackup") + + err := backuper.executeAndUploadBasebackup(ctx) + if err == nil { + backuper.log.Info("Basebackup completed successfully") + return + } + + backuper.log.Error("Basebackup failed", "error", err) + backuper.reportError(ctx, err.Error()) + + delay := retryDelay + if retryDelayOverride != nil { + delay = *retryDelayOverride + } + + backuper.log.Info("Retrying basebackup after delay", "delay", delay) + + select { + case <-ctx.Done(): + return + case <-time.After(delay): + } + } +} + +func (backuper *FullBackuper) executeAndUploadBasebackup(ctx context.Context) error { + cmd := backuper.cmdBuilder(ctx) + + var stderrBuf bytes.Buffer + cmd.Stderr = &stderrBuf + + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("create stdout pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + return fmt.Errorf("start pg_basebackup: %w", err) + } + + // Phase 1: Stream compressed data via io.Pipe directly to the API. + pipeReader, pipeWriter := io.Pipe() + go backuper.compressAndStream(pipeWriter, stdoutPipe) + + uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout) + defer cancel() + + uploadResp, uploadErr := backuper.apiClient.UploadBasebackup(uploadCtx, pipeReader) + + cmdErr := cmd.Wait() + + if uploadErr != nil { + return fmt.Errorf("upload basebackup: %w", uploadErr) + } + + if cmdErr != nil { + errMsg := fmt.Sprintf("pg_basebackup exited with error: %v (stderr: %s)", cmdErr, stderrBuf.String()) + _ = backuper.apiClient.FinalizeBasebackupWithError(ctx, uploadResp.BackupID, errMsg) + + return fmt.Errorf("pg_basebackup: %w", cmdErr) + } + + // Phase 2: Parse stderr for WAL segments and finalize the backup. + stderrStr := stderrBuf.String() + backuper.log.Debug("pg_basebackup stderr", "stderr", stderrStr) + + startSegment, stopSegment, err := ParseBasebackupStderr(stderrStr) + if err != nil { + errMsg := fmt.Sprintf("parse pg_basebackup stderr: %v", err) + _ = backuper.apiClient.FinalizeBasebackupWithError(ctx, uploadResp.BackupID, errMsg) + + return fmt.Errorf("parse pg_basebackup stderr: %w", err) + } + + backuper.log.Info("Basebackup WAL segments parsed", + "startSegment", startSegment, + "stopSegment", stopSegment, + "backupId", uploadResp.BackupID, + ) + + if err := backuper.apiClient.FinalizeBasebackup(ctx, uploadResp.BackupID, startSegment, stopSegment); err != nil { + return fmt.Errorf("finalize basebackup: %w", err) + } + + return nil +} + +func (backuper *FullBackuper) compressAndStream(pipeWriter *io.PipeWriter, reader io.Reader) { + encoder, err := zstd.NewWriter(pipeWriter, + zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(5)), + zstd.WithEncoderCRC(true), + ) + if err != nil { + _ = pipeWriter.CloseWithError(fmt.Errorf("create zstd encoder: %w", err)) + return + } + + if _, err := io.Copy(encoder, reader); err != nil { + _ = encoder.Close() + _ = pipeWriter.CloseWithError(fmt.Errorf("compress: %w", err)) + return + } + + if err := encoder.Close(); err != nil { + _ = pipeWriter.CloseWithError(fmt.Errorf("close encoder: %w", err)) + return + } + + _ = pipeWriter.Close() +} + +func (backuper *FullBackuper) reportError(ctx context.Context, errMsg string) { + if err := backuper.apiClient.ReportBackupError(ctx, errMsg); err != nil { + backuper.log.Error("Failed to report error to server", "error", err) + } +} + +func (backuper *FullBackuper) defaultCmdBuilder(ctx context.Context) *exec.Cmd { + switch backuper.cfg.PgType { + case "docker": + return backuper.buildDockerCmd(ctx) + default: + return backuper.buildHostCmd(ctx) + } +} + +func (backuper *FullBackuper) buildHostCmd(ctx context.Context) *exec.Cmd { + binary := "pg_basebackup" + if backuper.cfg.PgHostBinDir != "" { + binary = filepath.Join(backuper.cfg.PgHostBinDir, "pg_basebackup") + } + + cmd := exec.CommandContext(ctx, binary, + "-Ft", "-D", "-", "-X", "none", "--verbose", + "-h", backuper.cfg.PgHost, + "-p", fmt.Sprintf("%d", backuper.cfg.PgPort), + "-U", backuper.cfg.PgUser, + ) + + cmd.Env = append(os.Environ(), "PGPASSWORD="+backuper.cfg.PgPassword) + + return cmd +} + +func (backuper *FullBackuper) buildDockerCmd(ctx context.Context) *exec.Cmd { + cmd := exec.CommandContext(ctx, "docker", "exec", + "-e", "PGPASSWORD="+backuper.cfg.PgPassword, + "-i", backuper.cfg.PgDockerContainerName, + "pg_basebackup", + "-Ft", "-D", "-", "-X", "none", "--verbose", + "-h", backuper.cfg.PgHost, + "-p", fmt.Sprintf("%d", backuper.cfg.PgPort), + "-U", backuper.cfg.PgUser, + ) + + return cmd +} diff --git a/agent/internal/features/full_backup/backuper_test.go b/agent/internal/features/full_backup/backuper_test.go new file mode 100644 index 0000000..72ecc09 --- /dev/null +++ b/agent/internal/features/full_backup/backuper_test.go @@ -0,0 +1,671 @@ +package full_backup + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/klauspost/compress/zstd" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "databasus-agent/internal/config" + "databasus-agent/internal/features/api" + "databasus-agent/internal/logger" +) + +const ( + testChainValidPath = "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup" + testNextBackupTimePath = "/api/v1/backups/postgres/wal/next-full-backup-time" + testFullStartPath = "/api/v1/backups/postgres/wal/upload/full-start" + testFullCompletePath = "/api/v1/backups/postgres/wal/upload/full-complete" + testReportErrorPath = "/api/v1/backups/postgres/wal/error" + + testBackupID = "test-backup-id-1234" +) + +func Test_RunFullBackup_WhenChainBroken_BasebackupTriggered(t *testing.T) { + var mu sync.Mutex + var uploadReceived bool + var uploadHeaders http.Header + var finalizeReceived bool + var finalizeBody map[string]any + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{ + IsValid: false, + Error: "wal_chain_broken", + LastContiguousSegment: "000000010000000100000011", + }) + case testFullStartPath: + mu.Lock() + uploadReceived = true + uploadHeaders = r.Header.Clone() + mu.Unlock() + + _, _ = io.ReadAll(r.Body) + writeJSON(w, map[string]string{"backupId": testBackupID}) + case testFullCompletePath: + mu.Lock() + finalizeReceived = true + _ = json.NewDecoder(r.Body).Decode(&finalizeBody) + mu.Unlock() + + w.WriteHeader(http.StatusOK) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "test-backup-data", validStderr()) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go fb.Run(ctx) + waitForCondition(t, func() bool { + mu.Lock() + defer mu.Unlock() + return finalizeReceived + }, 5*time.Second) + cancel() + + mu.Lock() + defer mu.Unlock() + + assert.True(t, uploadReceived) + assert.Equal(t, "application/octet-stream", uploadHeaders.Get("Content-Type")) + assert.Equal(t, "test-token", uploadHeaders.Get("Authorization")) + + assert.True(t, finalizeReceived) + assert.Equal(t, testBackupID, finalizeBody["backupId"]) + assert.Equal(t, "000000010000000000000002", finalizeBody["startSegment"]) + assert.Equal(t, "000000010000000000000002", finalizeBody["stopSegment"]) +} + +func Test_RunFullBackup_WhenScheduledBackupDue_BasebackupTriggered(t *testing.T) { + var mu sync.Mutex + var finalizeReceived bool + + pastTime := time.Now().UTC().Add(-1 * time.Hour) + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{IsValid: true}) + case testNextBackupTimePath: + writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: &pastTime}) + case testFullStartPath: + _, _ = io.ReadAll(r.Body) + writeJSON(w, map[string]string{"backupId": testBackupID}) + case testFullCompletePath: + mu.Lock() + finalizeReceived = true + mu.Unlock() + + w.WriteHeader(http.StatusOK) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "scheduled-backup-data", validStderr()) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go fb.Run(ctx) + waitForCondition(t, func() bool { + mu.Lock() + defer mu.Unlock() + return finalizeReceived + }, 5*time.Second) + cancel() + + mu.Lock() + defer mu.Unlock() + + assert.True(t, finalizeReceived) +} + +func Test_RunFullBackup_WhenNoFullBackupExists_ImmediateBasebackupTriggered(t *testing.T) { + var mu sync.Mutex + var finalizeReceived bool + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{ + IsValid: false, + Error: "no_full_backup", + }) + case testFullStartPath: + _, _ = io.ReadAll(r.Body) + writeJSON(w, map[string]string{"backupId": testBackupID}) + case testFullCompletePath: + mu.Lock() + finalizeReceived = true + mu.Unlock() + + w.WriteHeader(http.StatusOK) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "first-backup-data", validStderr()) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go fb.Run(ctx) + waitForCondition(t, func() bool { + mu.Lock() + defer mu.Unlock() + return finalizeReceived + }, 5*time.Second) + cancel() + + mu.Lock() + defer mu.Unlock() + + assert.True(t, finalizeReceived) +} + +func Test_RunFullBackup_WhenUploadFails_RetriesAfterDelay(t *testing.T) { + var mu sync.Mutex + var uploadAttempts int + var errorReported bool + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{ + IsValid: false, + Error: "no_full_backup", + }) + case testFullStartPath: + _, _ = io.ReadAll(r.Body) + + mu.Lock() + uploadAttempts++ + attempt := uploadAttempts + mu.Unlock() + + if attempt == 1 { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"storage unavailable"}`)) + return + } + + writeJSON(w, map[string]string{"backupId": testBackupID}) + case testFullCompletePath: + w.WriteHeader(http.StatusOK) + case testReportErrorPath: + mu.Lock() + errorReported = true + mu.Unlock() + + w.WriteHeader(http.StatusOK) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "retry-backup-data", validStderr()) + + origRetryDelay := retryDelay + setRetryDelay(100 * time.Millisecond) + defer setRetryDelay(origRetryDelay) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + go fb.Run(ctx) + waitForCondition(t, func() bool { + mu.Lock() + defer mu.Unlock() + return uploadAttempts >= 2 + }, 10*time.Second) + cancel() + + mu.Lock() + defer mu.Unlock() + + assert.GreaterOrEqual(t, uploadAttempts, 2) + assert.True(t, errorReported) +} + +func Test_RunFullBackup_WhenAlreadyRunning_SkipsExecution(t *testing.T) { + var mu sync.Mutex + var uploadCount int + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{ + IsValid: false, + Error: "no_full_backup", + }) + case testFullStartPath: + _, _ = io.ReadAll(r.Body) + + mu.Lock() + uploadCount++ + mu.Unlock() + + writeJSON(w, map[string]string{"backupId": testBackupID}) + case testFullCompletePath: + w.WriteHeader(http.StatusOK) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr()) + + fb.isRunning.Store(true) + + fb.checkAndRunIfNeeded(context.Background()) + + mu.Lock() + count := uploadCount + mu.Unlock() + + assert.Equal(t, 0, count, "should not trigger backup when already running") +} + +func Test_RunFullBackup_WhenContextCancelled_StopsCleanly(t *testing.T) { + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{ + IsValid: false, + Error: "no_full_backup", + }) + case testFullStartPath: + _, _ = io.ReadAll(r.Body) + w.WriteHeader(http.StatusInternalServerError) + case testFullCompletePath: + w.WriteHeader(http.StatusOK) + case testReportErrorPath: + w.WriteHeader(http.StatusOK) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr()) + + origRetryDelay := retryDelay + setRetryDelay(5 * time.Second) + defer setRetryDelay(origRetryDelay) + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + done := make(chan struct{}) + go func() { + fb.Run(ctx) + close(done) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Run should have stopped after context cancellation") + } +} + +func Test_RunFullBackup_WhenChainValidAndNotScheduled_NoBasebackupTriggered(t *testing.T) { + var uploadReceived atomic.Bool + + futureTime := time.Now().UTC().Add(24 * time.Hour) + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{IsValid: true}) + case testNextBackupTimePath: + writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: &futureTime}) + case testFullStartPath: + uploadReceived.Store(true) + + _, _ = io.ReadAll(r.Body) + writeJSON(w, map[string]string{"backupId": testBackupID}) + case testFullCompletePath: + w.WriteHeader(http.StatusOK) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr()) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + go fb.Run(ctx) + time.Sleep(500 * time.Millisecond) + cancel() + + assert.False(t, uploadReceived.Load(), "should not trigger backup when chain valid and not scheduled") +} + +func Test_RunFullBackup_WhenStderrParsingFails_FinalizesWithErrorAndRetries(t *testing.T) { + var mu sync.Mutex + var errorReported bool + var finalizeWithErrorReceived bool + var finalizeBody map[string]any + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{ + IsValid: false, + Error: "no_full_backup", + }) + case testFullStartPath: + _, _ = io.ReadAll(r.Body) + writeJSON(w, map[string]string{"backupId": testBackupID}) + case testFullCompletePath: + mu.Lock() + finalizeWithErrorReceived = true + _ = json.NewDecoder(r.Body).Decode(&finalizeBody) + mu.Unlock() + + w.WriteHeader(http.StatusOK) + case testReportErrorPath: + mu.Lock() + errorReported = true + mu.Unlock() + + w.WriteHeader(http.StatusOK) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "data", "pg_basebackup: unexpected output with no LSN info") + + origRetryDelay := retryDelay + setRetryDelay(100 * time.Millisecond) + defer setRetryDelay(origRetryDelay) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + go fb.Run(ctx) + waitForCondition(t, func() bool { + mu.Lock() + defer mu.Unlock() + return errorReported + }, 2*time.Second) + cancel() + + mu.Lock() + defer mu.Unlock() + + assert.True(t, errorReported) + assert.True(t, finalizeWithErrorReceived, "should finalize with error when stderr parsing fails") + assert.Equal(t, testBackupID, finalizeBody["backupId"]) + assert.NotNil(t, finalizeBody["error"], "finalize should include error message") +} + +func Test_RunFullBackup_WhenNextBackupTimeNull_BasebackupTriggered(t *testing.T) { + var mu sync.Mutex + var finalizeReceived bool + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{IsValid: true}) + case testNextBackupTimePath: + writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: nil}) + case testFullStartPath: + _, _ = io.ReadAll(r.Body) + writeJSON(w, map[string]string{"backupId": testBackupID}) + case testFullCompletePath: + mu.Lock() + finalizeReceived = true + mu.Unlock() + + w.WriteHeader(http.StatusOK) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "first-run-data", validStderr()) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go fb.Run(ctx) + waitForCondition(t, func() bool { + mu.Lock() + defer mu.Unlock() + return finalizeReceived + }, 5*time.Second) + cancel() + + mu.Lock() + defer mu.Unlock() + + assert.True(t, finalizeReceived) +} + +func Test_RunFullBackup_WhenChainValidityReturns401_NoBasebackupTriggered(t *testing.T) { + var uploadReceived atomic.Bool + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid token"}`)) + case testFullStartPath: + uploadReceived.Store(true) + + _, _ = io.ReadAll(r.Body) + writeJSON(w, map[string]string{"backupId": testBackupID}) + case testFullCompletePath: + w.WriteHeader(http.StatusOK) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr()) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + go fb.Run(ctx) + time.Sleep(500 * time.Millisecond) + cancel() + + assert.False(t, uploadReceived.Load(), "should not trigger backup when API returns 401") +} + +func Test_RunFullBackup_WhenUploadSucceeds_BodyIsZstdCompressed(t *testing.T) { + var mu sync.Mutex + var receivedBody []byte + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{ + IsValid: false, + Error: "no_full_backup", + }) + case testFullStartPath: + body, _ := io.ReadAll(r.Body) + + mu.Lock() + receivedBody = body + mu.Unlock() + + writeJSON(w, map[string]string{"backupId": testBackupID}) + case testFullCompletePath: + w.WriteHeader(http.StatusOK) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + originalContent := "test-backup-content-for-compression-check" + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, originalContent, validStderr()) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go fb.Run(ctx) + waitForCondition(t, func() bool { + mu.Lock() + defer mu.Unlock() + return len(receivedBody) > 0 + }, 5*time.Second) + cancel() + + mu.Lock() + body := receivedBody + mu.Unlock() + + decoder, err := zstd.NewReader(nil) + require.NoError(t, err) + defer decoder.Close() + + decompressed, err := decoder.DecodeAll(body, nil) + require.NoError(t, err) + assert.Equal(t, originalContent, string(decompressed)) +} + +func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server { + t.Helper() + + server := httptest.NewServer(handler) + t.Cleanup(server.Close) + + return server +} + +func newTestFullBackuper(serverURL string) *FullBackuper { + cfg := &config.Config{ + DatabasusHost: serverURL, + DbID: "test-db-id", + Token: "test-token", + PgHost: "localhost", + PgPort: 5432, + PgUser: "postgres", + PgPassword: "password", + PgType: "host", + } + + apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger()) + + return NewFullBackuper(cfg, apiClient, logger.GetLogger()) +} + +func mockCmdBuilder(t *testing.T, stdoutContent, stderrContent string) CmdBuilder { + t.Helper() + + return func(ctx context.Context) *exec.Cmd { + cmd := exec.CommandContext(ctx, os.Args[0], + "-test.run=TestHelperProcess", + "--", + stdoutContent, + stderrContent, + ) + + cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS=1") + + return cmd + } +} + +func TestHelperProcess(t *testing.T) { + if os.Getenv("GO_TEST_HELPER_PROCESS") != "1" { + return + } + + args := os.Args + for i, arg := range args { + if arg == "--" { + args = args[i+1:] + break + } + } + + if len(args) >= 1 { + _, _ = fmt.Fprint(os.Stdout, args[0]) + } + + if len(args) >= 2 { + _, _ = fmt.Fprint(os.Stderr, args[1]) + } + + os.Exit(0) +} + +func validStderr() string { + return `pg_basebackup: initiating base backup, waiting for checkpoint to complete +pg_basebackup: checkpoint completed +pg_basebackup: write-ahead log start point: 0/2000028, on timeline 1 +pg_basebackup: checkpoint redo point at 0/2000028 +pg_basebackup: write-ahead log end point: 0/2000100 +pg_basebackup: base backup completed` +} + +func writeJSON(w http.ResponseWriter, v any) { + w.Header().Set("Content-Type", "application/json") + + if err := json.NewEncoder(w).Encode(v); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } +} + +func waitForCondition(t *testing.T, condition func() bool, timeout time.Duration) { + t.Helper() + + deadline := time.Now().Add(timeout) + + for time.Now().Before(deadline) { + if condition() { + return + } + + time.Sleep(50 * time.Millisecond) + } + + t.Fatalf("condition not met within %v", timeout) +} + +func setRetryDelay(d time.Duration) { + retryDelayOverride = &d +} + +func init() { + retryDelayOverride = nil +} diff --git a/agent/internal/features/full_backup/stderr_parser.go b/agent/internal/features/full_backup/stderr_parser.go new file mode 100644 index 0000000..7c6069c --- /dev/null +++ b/agent/internal/features/full_backup/stderr_parser.go @@ -0,0 +1,75 @@ +package full_backup + +import ( + "fmt" + "regexp" + "strconv" + "strings" +) + +const defaultWalSegmentSize uint32 = 16 * 1024 * 1024 // 16 MB + +var ( + startLSNRegex = regexp.MustCompile(`checkpoint redo point at ([0-9A-Fa-f]+/[0-9A-Fa-f]+)`) + stopLSNRegex = regexp.MustCompile(`write-ahead log end point: ([0-9A-Fa-f]+/[0-9A-Fa-f]+)`) +) + +func ParseBasebackupStderr(stderr string) (startSegment, stopSegment string, err error) { + startMatch := startLSNRegex.FindStringSubmatch(stderr) + if len(startMatch) < 2 { + return "", "", fmt.Errorf("failed to parse start WAL location from pg_basebackup stderr") + } + + stopMatch := stopLSNRegex.FindStringSubmatch(stderr) + if len(stopMatch) < 2 { + return "", "", fmt.Errorf("failed to parse stop WAL location from pg_basebackup stderr") + } + + startSegment, err = LSNToSegmentName(startMatch[1], 1, defaultWalSegmentSize) + if err != nil { + return "", "", fmt.Errorf("failed to convert start LSN to segment name: %w", err) + } + + stopSegment, err = LSNToSegmentName(stopMatch[1], 1, defaultWalSegmentSize) + if err != nil { + return "", "", fmt.Errorf("failed to convert stop LSN to segment name: %w", err) + } + + return startSegment, stopSegment, nil +} + +func LSNToSegmentName(lsn string, timelineID, walSegmentSize uint32) (string, error) { + high, low, err := parseLSN(lsn) + if err != nil { + return "", err + } + + segmentsPerXLogID := uint32(0x100000000 / uint64(walSegmentSize)) + logID := high + segmentOffset := low / walSegmentSize + + if segmentOffset >= segmentsPerXLogID { + return "", fmt.Errorf("segment offset %d exceeds segments per XLogId %d", segmentOffset, segmentsPerXLogID) + } + + return fmt.Sprintf("%08X%08X%08X", timelineID, logID, segmentOffset), nil +} + +func parseLSN(lsn string) (high, low uint32, err error) { + parts := strings.SplitN(lsn, "/", 2) + if len(parts) != 2 { + return 0, 0, fmt.Errorf("invalid LSN format: %q (expected X/Y)", lsn) + } + + highVal, err := strconv.ParseUint(parts[0], 16, 32) + if err != nil { + return 0, 0, fmt.Errorf("invalid LSN high part %q: %w", parts[0], err) + } + + lowVal, err := strconv.ParseUint(parts[1], 16, 32) + if err != nil { + return 0, 0, fmt.Errorf("invalid LSN low part %q: %w", parts[1], err) + } + + return uint32(highVal), uint32(lowVal), nil +} diff --git a/agent/internal/features/full_backup/stderr_parser_test.go b/agent/internal/features/full_backup/stderr_parser_test.go new file mode 100644 index 0000000..cd833fc --- /dev/null +++ b/agent/internal/features/full_backup/stderr_parser_test.go @@ -0,0 +1,162 @@ +package full_backup + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ParseBasebackupStderr_WithPG17Output_ExtractsCorrectSegments(t *testing.T) { + stderr := `pg_basebackup: initiating base backup, waiting for checkpoint to complete +pg_basebackup: checkpoint completed +pg_basebackup: write-ahead log start point: 0/2000028, on timeline 1 +pg_basebackup: starting background WAL receiver +pg_basebackup: checkpoint redo point at 0/2000028 +pg_basebackup: write-ahead log end point: 0/2000100 +pg_basebackup: waiting for background process to finish streaming ... +pg_basebackup: syncing data to disk ... +pg_basebackup: renaming backup_manifest.tmp to backup_manifest +pg_basebackup: base backup completed` + + startSeg, stopSeg, err := ParseBasebackupStderr(stderr) + + require.NoError(t, err) + assert.Equal(t, "000000010000000000000002", startSeg) + assert.Equal(t, "000000010000000000000002", stopSeg) +} + +func Test_ParseBasebackupStderr_WithPG15Output_ExtractsCorrectSegments(t *testing.T) { + stderr := `pg_basebackup: initiating base backup, waiting for checkpoint to complete +pg_basebackup: checkpoint completed +pg_basebackup: write-ahead log start point: 1/AB000028, on timeline 1 +pg_basebackup: checkpoint redo point at 1/AB000028 +pg_basebackup: write-ahead log end point: 1/AC000000 +pg_basebackup: base backup completed` + + startSeg, stopSeg, err := ParseBasebackupStderr(stderr) + + require.NoError(t, err) + assert.Equal(t, "0000000100000001000000AB", startSeg) + assert.Equal(t, "0000000100000001000000AC", stopSeg) +} + +func Test_ParseBasebackupStderr_WithHighLogID_ExtractsCorrectSegments(t *testing.T) { + stderr := `pg_basebackup: checkpoint redo point at A/FF000028 +pg_basebackup: write-ahead log end point: B/1000000` + + startSeg, stopSeg, err := ParseBasebackupStderr(stderr) + + require.NoError(t, err) + assert.Equal(t, "000000010000000A000000FF", startSeg) + assert.Equal(t, "000000010000000B00000001", stopSeg) +} + +func Test_ParseBasebackupStderr_WhenStartLSNMissing_ReturnsError(t *testing.T) { + stderr := `pg_basebackup: write-ahead log end point: 0/2000100 +pg_basebackup: base backup completed` + + _, _, err := ParseBasebackupStderr(stderr) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse start WAL location") +} + +func Test_ParseBasebackupStderr_WhenStopLSNMissing_ReturnsError(t *testing.T) { + stderr := `pg_basebackup: checkpoint redo point at 0/2000028 +pg_basebackup: base backup completed` + + _, _, err := ParseBasebackupStderr(stderr) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse stop WAL location") +} + +func Test_ParseBasebackupStderr_WhenEmptyStderr_ReturnsError(t *testing.T) { + _, _, err := ParseBasebackupStderr("") + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse start WAL location") +} + +func Test_LSNToSegmentName_WithBoundaryValues_ConvertsCorrectly(t *testing.T) { + tests := []struct { + name string + lsn string + timeline uint32 + segSize uint32 + expected string + }{ + { + name: "first segment", + lsn: "0/1000000", + timeline: 1, + segSize: 16 * 1024 * 1024, + expected: "000000010000000000000001", + }, + { + name: "segment at boundary FF", + lsn: "0/FF000000", + timeline: 1, + segSize: 16 * 1024 * 1024, + expected: "0000000100000000000000FF", + }, + { + name: "segment in second log file", + lsn: "1/0", + timeline: 1, + segSize: 16 * 1024 * 1024, + expected: "000000010000000100000000", + }, + { + name: "segment with offset within 16MB", + lsn: "0/200ABCD", + timeline: 1, + segSize: 16 * 1024 * 1024, + expected: "000000010000000000000002", + }, + { + name: "zero LSN", + lsn: "0/0", + timeline: 1, + segSize: 16 * 1024 * 1024, + expected: "000000010000000000000000", + }, + { + name: "high timeline ID", + lsn: "0/1000000", + timeline: 2, + segSize: 16 * 1024 * 1024, + expected: "000000020000000000000001", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := LSNToSegmentName(tt.lsn, tt.timeline, tt.segSize) + + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + }) + } +} + +func Test_LSNToSegmentName_WithInvalidLSN_ReturnsError(t *testing.T) { + tests := []struct { + name string + lsn string + }{ + {name: "no slash", lsn: "012345"}, + {name: "empty string", lsn: ""}, + {name: "invalid hex high", lsn: "GG/0"}, + {name: "invalid hex low", lsn: "0/ZZ"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := LSNToSegmentName(tt.lsn, 1, 16*1024*1024) + + require.Error(t, err) + }) + } +} diff --git a/agent/internal/features/start/daemon.go b/agent/internal/features/start/daemon.go new file mode 100644 index 0000000..15e6a5d --- /dev/null +++ b/agent/internal/features/start/daemon.go @@ -0,0 +1,120 @@ +//go:build !windows + +package start + +import ( + "errors" + "fmt" + "log/slog" + "os" + "os/exec" + "syscall" + "time" +) + +const ( + logFileName = "databasus.log" + stopTimeout = 30 * time.Second + stopPollInterval = 500 * time.Millisecond + daemonStartupDelay = 500 * time.Millisecond +) + +func Stop(log *slog.Logger) error { + pid, err := ReadLockFilePID() + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return errors.New("agent is not running (no lock file found)") + } + + return fmt.Errorf("failed to read lock file: %w", err) + } + + if !isProcessAlive(pid) { + _ = os.Remove(lockFileName) + return fmt.Errorf("agent is not running (stale lock file removed, PID %d)", pid) + } + + log.Info("Sending SIGTERM to agent", "pid", pid) + + if err := syscall.Kill(pid, syscall.SIGTERM); err != nil { + return fmt.Errorf("failed to send SIGTERM to PID %d: %w", pid, err) + } + + deadline := time.Now().Add(stopTimeout) + for time.Now().Before(deadline) { + if !isProcessAlive(pid) { + log.Info("Agent stopped", "pid", pid) + return nil + } + + time.Sleep(stopPollInterval) + } + + return fmt.Errorf("agent (PID %d) did not stop within %s — process may be stuck", pid, stopTimeout) +} + +func Status(log *slog.Logger) error { + pid, err := ReadLockFilePID() + if err != nil { + if errors.Is(err, os.ErrNotExist) { + fmt.Println("Agent is not running") + return nil + } + + return fmt.Errorf("failed to read lock file: %w", err) + } + + if isProcessAlive(pid) { + fmt.Printf("Agent is running (PID %d)\n", pid) + } else { + fmt.Println("Agent is not running (stale lock file)") + _ = os.Remove(lockFileName) + } + + return nil +} + +func spawnDaemon(log *slog.Logger) (int, error) { + execPath, err := os.Executable() + if err != nil { + return 0, fmt.Errorf("failed to resolve executable path: %w", err) + } + + args := []string{"_run"} + + logFile, err := os.OpenFile(logFileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return 0, fmt.Errorf("failed to open log file %s: %w", logFileName, err) + } + + cwd, err := os.Getwd() + if err != nil { + _ = logFile.Close() + return 0, fmt.Errorf("failed to get working directory: %w", err) + } + + cmd := exec.Command(execPath, args...) + cmd.Dir = cwd + cmd.Stderr = logFile + cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true} + + if err := cmd.Start(); err != nil { + _ = logFile.Close() + return 0, fmt.Errorf("failed to start daemon process: %w", err) + } + + pid := cmd.Process.Pid + + // Detach — we don't wait for the child + _ = logFile.Close() + + time.Sleep(daemonStartupDelay) + + if !isProcessAlive(pid) { + return 0, fmt.Errorf("daemon process (PID %d) exited immediately — check %s for details", pid, logFileName) + } + + log.Info("Daemon spawned", "pid", pid, "log", logFileName) + + return pid, nil +} diff --git a/agent/internal/features/start/daemon_windows.go b/agent/internal/features/start/daemon_windows.go new file mode 100644 index 0000000..8743554 --- /dev/null +++ b/agent/internal/features/start/daemon_windows.go @@ -0,0 +1,20 @@ +//go:build windows + +package start + +import ( + "errors" + "log/slog" +) + +func Stop(log *slog.Logger) error { + return errors.New("stop is not supported on Windows — use Ctrl+C in the terminal where the agent is running") +} + +func Status(log *slog.Logger) error { + return errors.New("status is not supported on Windows — check the terminal where the agent is running") +} + +func spawnDaemon(_ *slog.Logger) (int, error) { + return 0, errors.New("daemon mode is not supported on Windows") +} diff --git a/agent/internal/features/start/lock.go b/agent/internal/features/start/lock.go new file mode 100644 index 0000000..58a0d10 --- /dev/null +++ b/agent/internal/features/start/lock.go @@ -0,0 +1,117 @@ +//go:build !windows + +package start + +import ( + "errors" + "fmt" + "io" + "log/slog" + "os" + "strconv" + "strings" + "syscall" +) + +const lockFileName = "databasus.lock" + +func AcquireLock(log *slog.Logger) (*os.File, error) { + f, err := os.OpenFile(lockFileName, os.O_CREATE|os.O_RDWR, 0o644) + if err != nil { + return nil, fmt.Errorf("failed to open lock file: %w", err) + } + + err = syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB) + if err == nil { + if err := writePID(f); err != nil { + _ = f.Close() + return nil, err + } + + log.Info("Process lock acquired", "pid", os.Getpid(), "lockFile", lockFileName) + + return f, nil + } + + if !errors.Is(err, syscall.EWOULDBLOCK) { + _ = f.Close() + return nil, fmt.Errorf("failed to acquire lock: %w", err) + } + + pid, pidErr := readLockPID(f) + _ = f.Close() + + if pidErr != nil { + return nil, fmt.Errorf("Another instance is already running.") + } + + return nil, fmt.Errorf("Another instance is already running (PID %d).", pid) +} + +func ReleaseLock(f *os.File) { + _ = syscall.Flock(int(f.Fd()), syscall.LOCK_UN) + _ = f.Close() + _ = os.Remove(lockFileName) +} + +func ReadLockFilePID() (int, error) { + f, err := os.Open(lockFileName) + if err != nil { + return 0, err + } + defer f.Close() + + return readLockPID(f) +} + +func writePID(f *os.File) error { + if err := f.Truncate(0); err != nil { + return fmt.Errorf("failed to truncate lock file: %w", err) + } + + if _, err := f.Seek(0, io.SeekStart); err != nil { + return fmt.Errorf("failed to seek lock file: %w", err) + } + + if _, err := fmt.Fprintf(f, "%d\n", os.Getpid()); err != nil { + return fmt.Errorf("failed to write PID to lock file: %w", err) + } + + return f.Sync() +} + +func readLockPID(f *os.File) (int, error) { + if _, err := f.Seek(0, io.SeekStart); err != nil { + return 0, err + } + + data, err := io.ReadAll(f) + if err != nil { + return 0, err + } + + s := strings.TrimSpace(string(data)) + if s == "" { + return 0, errors.New("lock file is empty") + } + + pid, err := strconv.Atoi(s) + if err != nil { + return 0, fmt.Errorf("invalid PID in lock file: %w", err) + } + + return pid, nil +} + +func isProcessAlive(pid int) bool { + err := syscall.Kill(pid, 0) + if err == nil { + return true + } + + if errors.Is(err, syscall.EPERM) { + return true + } + + return false +} diff --git a/agent/internal/features/start/lock_test.go b/agent/internal/features/start/lock_test.go new file mode 100644 index 0000000..f9c7293 --- /dev/null +++ b/agent/internal/features/start/lock_test.go @@ -0,0 +1,148 @@ +//go:build !windows + +package start + +import ( + "fmt" + "os" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "databasus-agent/internal/logger" +) + +func Test_AcquireLock_LockFileCreatedWithPID(t *testing.T) { + setupTempDir(t) + log := logger.GetLogger() + + lockFile, err := AcquireLock(log) + require.NoError(t, err) + defer ReleaseLock(lockFile) + + data, err := os.ReadFile(lockFileName) + require.NoError(t, err) + + pid, err := strconv.Atoi(strings.TrimSpace(string(data))) + require.NoError(t, err) + assert.Equal(t, os.Getpid(), pid) +} + +func Test_AcquireLock_SecondAcquireFails_WhenFirstHeld(t *testing.T) { + setupTempDir(t) + log := logger.GetLogger() + + first, err := AcquireLock(log) + require.NoError(t, err) + defer ReleaseLock(first) + + second, err := AcquireLock(log) + assert.Nil(t, second) + require.Error(t, err) + assert.Contains(t, err.Error(), "Another instance is already running") + assert.Contains(t, err.Error(), fmt.Sprintf("PID %d", os.Getpid())) +} + +func Test_AcquireLock_StaleLockReacquired_WhenProcessDead(t *testing.T) { + setupTempDir(t) + log := logger.GetLogger() + + err := os.WriteFile(lockFileName, []byte("999999999\n"), 0o644) + require.NoError(t, err) + + lockFile, err := AcquireLock(log) + require.NoError(t, err) + defer ReleaseLock(lockFile) + + data, err := os.ReadFile(lockFileName) + require.NoError(t, err) + + pid, err := strconv.Atoi(strings.TrimSpace(string(data))) + require.NoError(t, err) + assert.Equal(t, os.Getpid(), pid) +} + +func Test_ReleaseLock_LockFileRemoved(t *testing.T) { + setupTempDir(t) + log := logger.GetLogger() + + lockFile, err := AcquireLock(log) + require.NoError(t, err) + + ReleaseLock(lockFile) + + _, err = os.Stat(lockFileName) + assert.True(t, os.IsNotExist(err)) +} + +func Test_AcquireLock_ReacquiredAfterRelease(t *testing.T) { + setupTempDir(t) + log := logger.GetLogger() + + first, err := AcquireLock(log) + require.NoError(t, err) + ReleaseLock(first) + + second, err := AcquireLock(log) + require.NoError(t, err) + defer ReleaseLock(second) + + data, err := os.ReadFile(lockFileName) + require.NoError(t, err) + + pid, err := strconv.Atoi(strings.TrimSpace(string(data))) + require.NoError(t, err) + assert.Equal(t, os.Getpid(), pid) +} + +func Test_isProcessAlive_ReturnsTrueForSelf(t *testing.T) { + assert.True(t, isProcessAlive(os.Getpid())) +} + +func Test_isProcessAlive_ReturnsFalseForNonExistentPID(t *testing.T) { + assert.False(t, isProcessAlive(999999999)) +} + +func Test_readLockPID_ParsesValidPID(t *testing.T) { + setupTempDir(t) + + f, err := os.CreateTemp("", "lock-test-*") + require.NoError(t, err) + defer os.Remove(f.Name()) + + _, err = f.WriteString("12345\n") + require.NoError(t, err) + + pid, err := readLockPID(f) + require.NoError(t, err) + assert.Equal(t, 12345, pid) +} + +func Test_readLockPID_ReturnsErrorForEmptyFile(t *testing.T) { + setupTempDir(t) + + f, err := os.CreateTemp("", "lock-test-*") + require.NoError(t, err) + defer os.Remove(f.Name()) + + _, err = readLockPID(f) + require.Error(t, err) + assert.Contains(t, err.Error(), "lock file is empty") +} + +func setupTempDir(t *testing.T) string { + t.Helper() + + origDir, err := os.Getwd() + require.NoError(t, err) + + dir := t.TempDir() + require.NoError(t, os.Chdir(dir)) + + t.Cleanup(func() { _ = os.Chdir(origDir) }) + + return dir +} diff --git a/agent/internal/features/start/lock_watcher.go b/agent/internal/features/start/lock_watcher.go new file mode 100644 index 0000000..27e8ca4 --- /dev/null +++ b/agent/internal/features/start/lock_watcher.go @@ -0,0 +1,90 @@ +//go:build !windows + +package start + +import ( + "context" + "log/slog" + "os" + "syscall" + "time" +) + +const lockWatchInterval = 5 * time.Second + +type LockWatcher struct { + originalInode uint64 + cancel context.CancelFunc + log *slog.Logger +} + +func NewLockWatcher(lockFile *os.File, cancel context.CancelFunc, log *slog.Logger) (*LockWatcher, error) { + inode, err := getFileInode(lockFile) + if err != nil { + return nil, err + } + + return &LockWatcher{ + originalInode: inode, + cancel: cancel, + log: log, + }, nil +} + +func (w *LockWatcher) Run(ctx context.Context) { + ticker := time.NewTicker(lockWatchInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + w.check() + } + } +} + +func (w *LockWatcher) check() { + info, err := os.Stat(lockFileName) + if err != nil { + w.log.Error("Lock file disappeared, shutting down", "file", lockFileName, "error", err) + w.cancel() + + return + } + + currentInode, err := getStatInode(info) + if err != nil { + w.log.Error("Failed to read lock file inode, shutting down", "error", err) + w.cancel() + + return + } + + if currentInode != w.originalInode { + w.log.Error("Lock file was replaced (inode changed), shutting down", + "originalInode", w.originalInode, + "currentInode", currentInode, + ) + w.cancel() + } +} + +func getFileInode(f *os.File) (uint64, error) { + info, err := f.Stat() + if err != nil { + return 0, err + } + + return getStatInode(info) +} + +func getStatInode(info os.FileInfo) (uint64, error) { + stat, ok := info.Sys().(*syscall.Stat_t) + if !ok { + return 0, os.ErrInvalid + } + + return stat.Ino, nil +} diff --git a/agent/internal/features/start/lock_watcher_test.go b/agent/internal/features/start/lock_watcher_test.go new file mode 100644 index 0000000..d1e3ab2 --- /dev/null +++ b/agent/internal/features/start/lock_watcher_test.go @@ -0,0 +1,110 @@ +//go:build !windows + +package start + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "databasus-agent/internal/logger" +) + +func Test_NewLockWatcher_CapturesInode(t *testing.T) { + setupTempDir(t) + log := logger.GetLogger() + + lockFile, err := AcquireLock(log) + require.NoError(t, err) + defer ReleaseLock(lockFile) + + _, cancel := context.WithCancel(context.Background()) + defer cancel() + + watcher, err := NewLockWatcher(lockFile, cancel, log) + require.NoError(t, err) + assert.NotZero(t, watcher.originalInode) +} + +func Test_LockWatcher_FileUnchanged_ContextNotCancelled(t *testing.T) { + setupTempDir(t) + log := logger.GetLogger() + + lockFile, err := AcquireLock(log) + require.NoError(t, err) + defer ReleaseLock(lockFile) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + watcher, err := NewLockWatcher(lockFile, cancel, log) + require.NoError(t, err) + + watcher.check() + watcher.check() + watcher.check() + + select { + case <-ctx.Done(): + t.Fatal("context should not be cancelled when lock file is unchanged") + default: + } +} + +func Test_LockWatcher_FileDeleted_CancelsContext(t *testing.T) { + setupTempDir(t) + log := logger.GetLogger() + + lockFile, err := AcquireLock(log) + require.NoError(t, err) + defer ReleaseLock(lockFile) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + watcher, err := NewLockWatcher(lockFile, cancel, log) + require.NoError(t, err) + + err = os.Remove(lockFileName) + require.NoError(t, err) + + watcher.check() + + select { + case <-ctx.Done(): + default: + t.Fatal("context should be cancelled when lock file is deleted") + } +} + +func Test_LockWatcher_FileReplacedWithDifferentInode_CancelsContext(t *testing.T) { + setupTempDir(t) + log := logger.GetLogger() + + lockFile, err := AcquireLock(log) + require.NoError(t, err) + defer ReleaseLock(lockFile) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + watcher, err := NewLockWatcher(lockFile, cancel, log) + require.NoError(t, err) + + err = os.Remove(lockFileName) + require.NoError(t, err) + + err = os.WriteFile(lockFileName, []byte("99999\n"), 0o644) + require.NoError(t, err) + + watcher.check() + + select { + case <-ctx.Done(): + default: + t.Fatal("context should be cancelled when lock file inode changes") + } +} diff --git a/agent/internal/features/start/lock_watcher_windows.go b/agent/internal/features/start/lock_watcher_windows.go new file mode 100644 index 0000000..3b41ba8 --- /dev/null +++ b/agent/internal/features/start/lock_watcher_windows.go @@ -0,0 +1,17 @@ +//go:build windows + +package start + +import ( + "context" + "log/slog" + "os" +) + +type LockWatcher struct{} + +func NewLockWatcher(_ *os.File, _ context.CancelFunc, _ *slog.Logger) (*LockWatcher, error) { + return &LockWatcher{}, nil +} + +func (w *LockWatcher) Run(_ context.Context) {} diff --git a/agent/internal/features/start/lock_windows.go b/agent/internal/features/start/lock_windows.go new file mode 100644 index 0000000..c8fa44e --- /dev/null +++ b/agent/internal/features/start/lock_windows.go @@ -0,0 +1,18 @@ +package start + +import ( + "log/slog" + "os" +) + +func AcquireLock(log *slog.Logger) (*os.File, error) { + log.Warn("Process locking is not supported on Windows, skipping") + + return nil, nil +} + +func ReleaseLock(f *os.File) { + if f != nil { + _ = f.Close() + } +} diff --git a/agent/internal/features/start/start.go b/agent/internal/features/start/start.go index 93470b8..5d4c578 100644 --- a/agent/internal/features/start/start.go +++ b/agent/internal/features/start/start.go @@ -5,22 +5,32 @@ import ( "errors" "fmt" "log/slog" + "os" "os/exec" + "os/signal" "path/filepath" + "runtime" + "strconv" "strings" + "syscall" "time" "github.com/jackc/pgx/v5" "databasus-agent/internal/config" + "databasus-agent/internal/features/api" + full_backup "databasus-agent/internal/features/full_backup" + "databasus-agent/internal/features/upgrade" + "databasus-agent/internal/features/wal" ) const ( pgBasebackupVerifyTimeout = 10 * time.Second dbVerifyTimeout = 10 * time.Second + minPgMajorVersion = 15 ) -func Run(cfg *config.Config, log *slog.Logger) error { +func Start(cfg *config.Config, agentVersion string, isDev bool, log *slog.Logger) error { if err := validateConfig(cfg); err != nil { return err } @@ -33,10 +43,59 @@ func Run(cfg *config.Config, log *slog.Logger) error { return err } - log.Info("start: stub — not yet implemented", - "dbId", cfg.DbID, - "hasToken", cfg.Token != "", - ) + if runtime.GOOS == "windows" { + return RunDaemon(cfg, agentVersion, isDev, log) + } + + pid, err := spawnDaemon(log) + if err != nil { + return err + } + + fmt.Printf("Agent started in background (PID %d)\n", pid) + + return nil +} + +func RunDaemon(cfg *config.Config, agentVersion string, isDev bool, log *slog.Logger) error { + lockFile, err := AcquireLock(log) + if err != nil { + return err + } + defer ReleaseLock(lockFile) + + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() + + watcher, err := NewLockWatcher(lockFile, cancel, log) + if err != nil { + return fmt.Errorf("failed to initialize lock watcher: %w", err) + } + go watcher.Run(ctx) + + apiClient := api.NewClient(cfg.DatabasusHost, cfg.Token, log) + + var backgroundUpgrader *upgrade.BackgroundUpgrader + if agentVersion != "dev" && runtime.GOOS != "windows" { + backgroundUpgrader = upgrade.NewBackgroundUpgrader(apiClient, agentVersion, isDev, cancel, log) + go backgroundUpgrader.Run(ctx) + } + + fullBackuper := full_backup.NewFullBackuper(cfg, apiClient, log) + go fullBackuper.Run(ctx) + + streamer := wal.NewStreamer(cfg, apiClient, log) + streamer.Run(ctx) + + if backgroundUpgrader != nil { + backgroundUpgrader.WaitForCompletion(30 * time.Second) + + if backgroundUpgrader.IsUpgraded() { + return upgrade.ErrUpgradeRestart + } + } + + log.Info("Agent stopped") return nil } @@ -70,8 +129,8 @@ func validateConfig(cfg *config.Config) error { return fmt.Errorf("argument pg-type must be 'host' or 'docker', got '%s'", cfg.PgType) } - if cfg.WalDir == "" { - return errors.New("argument wal-dir is required") + if cfg.PgWalDir == "" { + return errors.New("argument pg-wal-dir is required") } if cfg.PgType == "docker" && cfg.PgDockerContainerName == "" { @@ -169,11 +228,44 @@ func verifyDatabase(cfg *config.Config, log *slog.Logger) error { ) } + var versionNumStr string + if err := conn.QueryRow(ctx, "SHOW server_version_num").Scan(&versionNumStr); err != nil { + return fmt.Errorf("failed to query PostgreSQL version: %w", err) + } + + majorVersion, err := parsePgVersionNum(versionNumStr) + if err != nil { + return fmt.Errorf("failed to parse PostgreSQL version '%s': %w", versionNumStr, err) + } + + if majorVersion < minPgMajorVersion { + return fmt.Errorf( + "PostgreSQL %d is not supported, minimum required version is %d", + majorVersion, minPgMajorVersion, + ) + } + log.Info("PostgreSQL connection verified", "host", cfg.PgHost, "port", cfg.PgPort, "user", cfg.PgUser, + "version", majorVersion, ) return nil } + +func parsePgVersionNum(versionNumStr string) (int, error) { + versionNum, err := strconv.Atoi(strings.TrimSpace(versionNumStr)) + if err != nil { + return 0, fmt.Errorf("invalid version number: %w", err) + } + + if versionNum <= 0 { + return 0, fmt.Errorf("invalid version number: %d", versionNum) + } + + majorVersion := versionNum / 10000 + + return majorVersion, nil +} diff --git a/agent/internal/features/start/start_test.go b/agent/internal/features/start/start_test.go new file mode 100644 index 0000000..b7ea8a5 --- /dev/null +++ b/agent/internal/features/start/start_test.go @@ -0,0 +1,84 @@ +package start + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ParsePgVersionNum_SupportedVersions_ReturnsMajorVersion(t *testing.T) { + tests := []struct { + name string + versionNumStr string + expectedMajor int + }{ + {name: "PG 15.0", versionNumStr: "150000", expectedMajor: 15}, + {name: "PG 15.4", versionNumStr: "150004", expectedMajor: 15}, + {name: "PG 16.0", versionNumStr: "160000", expectedMajor: 16}, + {name: "PG 16.3", versionNumStr: "160003", expectedMajor: 16}, + {name: "PG 17.2", versionNumStr: "170002", expectedMajor: 17}, + {name: "PG 18.0", versionNumStr: "180000", expectedMajor: 18}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + major, err := parsePgVersionNum(tt.versionNumStr) + + require.NoError(t, err) + assert.Equal(t, tt.expectedMajor, major) + assert.GreaterOrEqual(t, major, minPgMajorVersion) + }) + } +} + +func Test_ParsePgVersionNum_UnsupportedVersions_ReturnsMajorVersionBelow15(t *testing.T) { + tests := []struct { + name string + versionNumStr string + expectedMajor int + }{ + {name: "PG 12.5", versionNumStr: "120005", expectedMajor: 12}, + {name: "PG 13.0", versionNumStr: "130000", expectedMajor: 13}, + {name: "PG 14.12", versionNumStr: "140012", expectedMajor: 14}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + major, err := parsePgVersionNum(tt.versionNumStr) + + require.NoError(t, err) + assert.Equal(t, tt.expectedMajor, major) + assert.Less(t, major, minPgMajorVersion) + }) + } +} + +func Test_ParsePgVersionNum_InvalidInput_ReturnsError(t *testing.T) { + tests := []struct { + name string + versionNumStr string + }{ + {name: "empty string", versionNumStr: ""}, + {name: "non-numeric", versionNumStr: "abc"}, + {name: "negative number", versionNumStr: "-1"}, + {name: "zero", versionNumStr: "0"}, + {name: "float", versionNumStr: "15.4"}, + {name: "whitespace only", versionNumStr: " "}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := parsePgVersionNum(tt.versionNumStr) + + require.Error(t, err) + }) + } +} + +func Test_ParsePgVersionNum_WithWhitespace_ParsesCorrectly(t *testing.T) { + major, err := parsePgVersionNum(" 150004 ") + + require.NoError(t, err) + assert.Equal(t, 15, major) +} diff --git a/agent/internal/features/upgrade/background_upgrader.go b/agent/internal/features/upgrade/background_upgrader.go new file mode 100644 index 0000000..4703195 --- /dev/null +++ b/agent/internal/features/upgrade/background_upgrader.go @@ -0,0 +1,88 @@ +package upgrade + +import ( + "context" + "log/slog" + "sync/atomic" + "time" + + "databasus-agent/internal/features/api" +) + +const backgroundCheckInterval = 5 * time.Second + +type BackgroundUpgrader struct { + apiClient *api.Client + currentVersion string + isDev bool + cancel context.CancelFunc + isUpgraded atomic.Bool + log *slog.Logger + done chan struct{} +} + +func NewBackgroundUpgrader( + apiClient *api.Client, + currentVersion string, + isDev bool, + cancel context.CancelFunc, + log *slog.Logger, +) *BackgroundUpgrader { + return &BackgroundUpgrader{ + apiClient, + currentVersion, + isDev, + cancel, + atomic.Bool{}, + log, + make(chan struct{}), + } +} + +func (u *BackgroundUpgrader) Run(ctx context.Context) { + defer close(u.done) + + ticker := time.NewTicker(backgroundCheckInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if u.checkAndUpgrade() { + return + } + } + } +} + +func (u *BackgroundUpgrader) IsUpgraded() bool { + return u.isUpgraded.Load() +} + +func (u *BackgroundUpgrader) WaitForCompletion(timeout time.Duration) { + select { + case <-u.done: + case <-time.After(timeout): + } +} + +func (u *BackgroundUpgrader) checkAndUpgrade() bool { + isUpgraded, err := CheckAndUpdate(u.apiClient, u.currentVersion, u.isDev, u.log) + if err != nil { + u.log.Warn("Background update check failed", "error", err) + + return false + } + + if !isUpgraded { + return false + } + + u.log.Info("Background upgrade complete, restarting...") + u.isUpgraded.Store(true) + u.cancel() + + return true +} diff --git a/agent/internal/features/upgrade/dto.go b/agent/internal/features/upgrade/dto.go deleted file mode 100644 index 11f251b..0000000 --- a/agent/internal/features/upgrade/dto.go +++ /dev/null @@ -1,5 +0,0 @@ -package upgrade - -type versionResponse struct { - Version string `json:"version"` -} diff --git a/agent/internal/features/upgrade/errors.go b/agent/internal/features/upgrade/errors.go new file mode 100644 index 0000000..ddecc64 --- /dev/null +++ b/agent/internal/features/upgrade/errors.go @@ -0,0 +1,5 @@ +package upgrade + +import "errors" + +var ErrUpgradeRestart = errors.New("agent upgraded, restart required") diff --git a/agent/internal/features/upgrade/upgrader.go b/agent/internal/features/upgrade/upgrader.go index a75ae2d..d50c534 100644 --- a/agent/internal/features/upgrade/upgrader.go +++ b/agent/internal/features/upgrade/upgrader.go @@ -2,44 +2,47 @@ package upgrade import ( "context" - "encoding/json" "fmt" - "io" "log/slog" - "net/http" "os" "os/exec" "runtime" "strings" - "syscall" - "time" + + "databasus-agent/internal/features/api" ) -func CheckAndUpdate(databasusHost, currentVersion string, isDev bool, log *slog.Logger) error { +// CheckAndUpdate checks if a new version is available and upgrades the binary on disk. +// Returns (true, nil) if the binary was upgraded, (false, nil) if already up to date, +// or (false, err) on failure. Callers are responsible for re-exec or restart signaling. +func CheckAndUpdate(apiClient *api.Client, currentVersion string, isDev bool, log *slog.Logger) (bool, error) { if isDev { log.Info("Skipping update check (development mode)") - return nil + + return false, nil } - serverVersion, err := fetchServerVersion(databasusHost, log) + serverVersion, err := apiClient.FetchServerVersion(context.Background()) if err != nil { - return fmt.Errorf( - "unable to check version, please verify Databasus server is available at %s: %w", - databasusHost, + log.Warn("Could not reach server for update check", "error", err) + + return false, fmt.Errorf( + "unable to check version, please verify Databasus server is available: %w", err, ) } if serverVersion == currentVersion { log.Info("Agent version is up to date", "version", currentVersion) - return nil + + return false, nil } log.Info("Updating agent...", "current", currentVersion, "target", serverVersion) selfPath, err := os.Executable() if err != nil { - return fmt.Errorf("failed to determine executable path: %w", err) + return false, fmt.Errorf("failed to determine executable path: %w", err) } tempPath := selfPath + ".update" @@ -48,93 +51,25 @@ func CheckAndUpdate(databasusHost, currentVersion string, isDev bool, log *slog. _ = os.Remove(tempPath) }() - if err := downloadBinary(databasusHost, tempPath); err != nil { - return fmt.Errorf("failed to download update: %w", err) + if err := apiClient.DownloadAgentBinary(context.Background(), runtime.GOARCH, tempPath); err != nil { + return false, fmt.Errorf("failed to download update: %w", err) } if err := os.Chmod(tempPath, 0o755); err != nil { - return fmt.Errorf("failed to set permissions on update: %w", err) + return false, fmt.Errorf("failed to set permissions on update: %w", err) } if err := verifyBinary(tempPath, serverVersion); err != nil { - return fmt.Errorf("update verification failed: %w", err) + return false, fmt.Errorf("update verification failed: %w", err) } if err := os.Rename(tempPath, selfPath); err != nil { - return fmt.Errorf("failed to replace binary (try --skip-update if this persists): %w", err) + return false, fmt.Errorf("failed to replace binary (try --skip-update if this persists): %w", err) } - log.Info("Update complete, re-executing...") + log.Info("Agent binary updated", "version", serverVersion) - return syscall.Exec(selfPath, os.Args, os.Environ()) -} - -func fetchServerVersion(host string, log *slog.Logger) (string, error) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - client := &http.Client{Timeout: 10 * time.Second} - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, host+"/api/v1/system/version", nil) - if err != nil { - return "", err - } - - resp, err := client.Do(req) - if err != nil { - log.Warn("Could not reach server for update check, continuing", "error", err) - return "", err - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - log.Warn( - "Server returned non-OK status for version check, continuing", - "status", - resp.StatusCode, - ) - return "", fmt.Errorf("status %d", resp.StatusCode) - } - - var ver versionResponse - if err := json.NewDecoder(resp.Body).Decode(&ver); err != nil { - log.Warn("Failed to parse server version response, continuing", "error", err) - return "", err - } - - return ver.Version, nil -} - -func downloadBinary(host, destPath string) error { - url := fmt.Sprintf("%s/api/v1/system/agent?arch=%s", host, runtime.GOARCH) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) - defer cancel() - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return err - } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return err - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("server returned %d for agent download", resp.StatusCode) - } - - f, err := os.Create(destPath) - if err != nil { - return err - } - defer func() { _ = f.Close() }() - - _, err = io.Copy(f, resp.Body) - - return err + return true, nil } func verifyBinary(binaryPath, expectedVersion string) error { diff --git a/agent/internal/features/wal/streamer.go b/agent/internal/features/wal/streamer.go new file mode 100644 index 0000000..c1513fd --- /dev/null +++ b/agent/internal/features/wal/streamer.go @@ -0,0 +1,182 @@ +package wal + +import ( + "context" + "fmt" + "io" + "log/slog" + "os" + "path/filepath" + "regexp" + "sort" + "strings" + "time" + + "github.com/klauspost/compress/zstd" + + "databasus-agent/internal/config" + "databasus-agent/internal/features/api" +) + +const ( + pollInterval = 2 * time.Second + uploadTimeout = 5 * time.Minute +) + +var segmentNameRegex = regexp.MustCompile(`^[0-9A-Fa-f]{24}$`) + +type Streamer struct { + cfg *config.Config + apiClient *api.Client + log *slog.Logger +} + +func NewStreamer(cfg *config.Config, apiClient *api.Client, log *slog.Logger) *Streamer { + return &Streamer{ + cfg: cfg, + apiClient: apiClient, + log: log, + } +} + +func (s *Streamer) Run(ctx context.Context) { + s.log.Info("WAL streamer started", "pgWalDir", s.cfg.PgWalDir) + + s.processQueue(ctx) + + ticker := time.NewTicker(pollInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + s.log.Info("WAL streamer stopping") + return + case <-ticker.C: + s.processQueue(ctx) + } + } +} + +func (s *Streamer) processQueue(ctx context.Context) { + segments, err := s.listSegments() + if err != nil { + s.log.Error("Failed to list WAL segments", "error", err) + return + } + + for _, segmentName := range segments { + if ctx.Err() != nil { + return + } + + if err := s.uploadSegment(ctx, segmentName); err != nil { + s.log.Error("Failed to upload WAL segment", + "segment", segmentName, + "error", err, + ) + return + } + } +} + +func (s *Streamer) listSegments() ([]string, error) { + entries, err := os.ReadDir(s.cfg.PgWalDir) + if err != nil { + return nil, fmt.Errorf("read wal dir: %w", err) + } + + var segments []string + + for _, entry := range entries { + if entry.IsDir() { + continue + } + + name := entry.Name() + + if strings.HasSuffix(name, ".tmp") { + continue + } + + if !segmentNameRegex.MatchString(name) { + continue + } + + segments = append(segments, name) + } + + sort.Strings(segments) + + return segments, nil +} + +func (s *Streamer) uploadSegment(ctx context.Context, segmentName string) error { + filePath := filepath.Join(s.cfg.PgWalDir, segmentName) + + pr, pw := io.Pipe() + + go s.compressAndStream(pw, filePath) + + uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout) + defer cancel() + + result, err := s.apiClient.UploadWalSegment(uploadCtx, segmentName, pr) + if err != nil { + return err + } + + if result.IsGapDetected { + s.log.Warn("WAL chain gap detected", + "segment", segmentName, + "expected", result.ExpectedSegmentName, + "received", result.ReceivedSegmentName, + ) + + return fmt.Errorf("gap detected for segment %s", segmentName) + } + + s.log.Debug("WAL segment uploaded", "segment", segmentName) + + if *s.cfg.IsDeleteWalAfterUpload { + if err := os.Remove(filePath); err != nil { + s.log.Warn("Failed to delete uploaded WAL segment", + "segment", segmentName, + "error", err, + ) + } + } + + return nil +} + +func (s *Streamer) compressAndStream(pw *io.PipeWriter, filePath string) { + f, err := os.Open(filePath) + if err != nil { + _ = pw.CloseWithError(fmt.Errorf("open file: %w", err)) + return + } + defer func() { _ = f.Close() }() + + encoder, err := zstd.NewWriter(pw, + zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(5)), + zstd.WithEncoderCRC(true), + ) + if err != nil { + _ = pw.CloseWithError(fmt.Errorf("create zstd encoder: %w", err)) + return + } + + if _, err := io.Copy(encoder, f); err != nil { + _ = encoder.Close() + _ = pw.CloseWithError(fmt.Errorf("compress: %w", err)) + return + } + + if err := encoder.Close(); err != nil { + _ = pw.CloseWithError(fmt.Errorf("close encoder: %w", err)) + return + } + + _ = pw.Close() +} diff --git a/agent/internal/features/wal/streamer_test.go b/agent/internal/features/wal/streamer_test.go new file mode 100644 index 0000000..b9deceb --- /dev/null +++ b/agent/internal/features/wal/streamer_test.go @@ -0,0 +1,348 @@ +package wal + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/klauspost/compress/zstd" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "databasus-agent/internal/config" + "databasus-agent/internal/features/api" + "databasus-agent/internal/logger" +) + +func Test_UploadSegment_SingleSegment_ServerReceivesCorrectHeadersAndBody(t *testing.T) { + walDir := createTestWalDir(t) + segmentContent := []byte("test-wal-segment-data-for-upload") + writeTestSegment(t, walDir, "000000010000000100000001", segmentContent) + + var receivedHeaders http.Header + var receivedBody []byte + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + + body, err := io.ReadAll(r.Body) + require.NoError(t, err) + receivedBody = body + + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + streamer := newTestStreamer(walDir, server.URL) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + go streamer.Run(ctx) + time.Sleep(500 * time.Millisecond) + cancel() + + require.NotNil(t, receivedHeaders) + assert.Equal(t, "test-token", receivedHeaders.Get("Authorization")) + assert.Equal(t, "application/octet-stream", receivedHeaders.Get("Content-Type")) + assert.Equal(t, "000000010000000100000001", receivedHeaders.Get("X-Wal-Segment-Name")) + + decompressed := decompressZstd(t, receivedBody) + assert.Equal(t, segmentContent, decompressed) +} + +func Test_UploadSegments_MultipleSegmentsOutOfOrder_UploadedInAscendingOrder(t *testing.T) { + walDir := createTestWalDir(t) + writeTestSegment(t, walDir, "000000010000000100000003", []byte("third")) + writeTestSegment(t, walDir, "000000010000000100000001", []byte("first")) + writeTestSegment(t, walDir, "000000010000000100000002", []byte("second")) + + var mu sync.Mutex + var uploadOrder []string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + uploadOrder = append(uploadOrder, r.Header.Get("X-Wal-Segment-Name")) + mu.Unlock() + + _, _ = io.ReadAll(r.Body) + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + streamer := newTestStreamer(walDir, server.URL) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + go streamer.Run(ctx) + time.Sleep(500 * time.Millisecond) + cancel() + + mu.Lock() + defer mu.Unlock() + + require.Len(t, uploadOrder, 3) + assert.Equal(t, "000000010000000100000001", uploadOrder[0]) + assert.Equal(t, "000000010000000100000002", uploadOrder[1]) + assert.Equal(t, "000000010000000100000003", uploadOrder[2]) +} + +func Test_UploadSegments_DirectoryHasTmpFiles_TmpFilesIgnored(t *testing.T) { + walDir := createTestWalDir(t) + writeTestSegment(t, walDir, "000000010000000100000001", []byte("real segment")) + writeTestSegment(t, walDir, "000000010000000100000002.tmp", []byte("partial copy")) + + var mu sync.Mutex + var uploadedSegments []string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + uploadedSegments = append(uploadedSegments, r.Header.Get("X-Wal-Segment-Name")) + mu.Unlock() + + _, _ = io.ReadAll(r.Body) + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + streamer := newTestStreamer(walDir, server.URL) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + go streamer.Run(ctx) + time.Sleep(500 * time.Millisecond) + cancel() + + mu.Lock() + defer mu.Unlock() + + require.Len(t, uploadedSegments, 1) + assert.Equal(t, "000000010000000100000001", uploadedSegments[0]) +} + +func Test_UploadSegment_DeleteEnabled_FileRemovedAfterUpload(t *testing.T) { + walDir := createTestWalDir(t) + segmentName := "000000010000000100000001" + writeTestSegment(t, walDir, segmentName, []byte("segment data")) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.ReadAll(r.Body) + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + isDeleteEnabled := true + cfg := createTestConfig(walDir, server.URL) + cfg.IsDeleteWalAfterUpload = &isDeleteEnabled + apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger()) + streamer := NewStreamer(cfg, apiClient, logger.GetLogger()) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + go streamer.Run(ctx) + time.Sleep(500 * time.Millisecond) + cancel() + + _, err := os.Stat(filepath.Join(walDir, segmentName)) + assert.True(t, os.IsNotExist(err), "segment file should be deleted after successful upload") +} + +func Test_UploadSegment_DeleteDisabled_FileKeptAfterUpload(t *testing.T) { + walDir := createTestWalDir(t) + segmentName := "000000010000000100000001" + writeTestSegment(t, walDir, segmentName, []byte("segment data")) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.ReadAll(r.Body) + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + isDeleteDisabled := false + cfg := createTestConfig(walDir, server.URL) + cfg.IsDeleteWalAfterUpload = &isDeleteDisabled + apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger()) + streamer := NewStreamer(cfg, apiClient, logger.GetLogger()) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + go streamer.Run(ctx) + time.Sleep(500 * time.Millisecond) + cancel() + + _, err := os.Stat(filepath.Join(walDir, segmentName)) + assert.NoError(t, err, "segment file should be kept when delete is disabled") +} + +func Test_UploadSegment_ServerReturns500_FileKeptInQueue(t *testing.T) { + walDir := createTestWalDir(t) + segmentName := "000000010000000100000001" + writeTestSegment(t, walDir, segmentName, []byte("segment data")) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.ReadAll(r.Body) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"internal server error"}`)) + })) + defer server.Close() + + streamer := newTestStreamer(walDir, server.URL) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + go streamer.Run(ctx) + time.Sleep(500 * time.Millisecond) + cancel() + + _, err := os.Stat(filepath.Join(walDir, segmentName)) + assert.NoError(t, err, "segment file should remain in queue after server error") +} + +func Test_ProcessQueue_EmptyDirectory_NoUploads(t *testing.T) { + walDir := createTestWalDir(t) + + uploadCount := 0 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + uploadCount++ + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + streamer := newTestStreamer(walDir, server.URL) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + go streamer.Run(ctx) + time.Sleep(500 * time.Millisecond) + cancel() + + assert.Equal(t, 0, uploadCount, "no uploads should occur for empty directory") +} + +func Test_Run_ContextCancelled_StopsImmediately(t *testing.T) { + walDir := createTestWalDir(t) + + streamer := newTestStreamer(walDir, "http://localhost:0") + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + done := make(chan struct{}) + go func() { + streamer.Run(ctx) + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("Run should have stopped immediately when context is already cancelled") + } +} + +func Test_UploadSegment_ServerReturns409_FileNotDeleted(t *testing.T) { + walDir := createTestWalDir(t) + segmentName := "000000010000000100000005" + writeTestSegment(t, walDir, segmentName, []byte("gap segment")) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.ReadAll(r.Body) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusConflict) + + resp := map[string]string{ + "error": "gap_detected", + "expectedSegmentName": "000000010000000100000003", + "receivedSegmentName": segmentName, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + streamer := newTestStreamer(walDir, server.URL) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + go streamer.Run(ctx) + time.Sleep(500 * time.Millisecond) + cancel() + + _, err := os.Stat(filepath.Join(walDir, segmentName)) + assert.NoError(t, err, "segment file should not be deleted on gap detection") +} + +func newTestStreamer(walDir, serverURL string) *Streamer { + cfg := createTestConfig(walDir, serverURL) + apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger()) + + return NewStreamer(cfg, apiClient, logger.GetLogger()) +} + +func createTestWalDir(t *testing.T) string { + t.Helper() + + baseDir := filepath.Join(".", ".test-tmp") + if err := os.MkdirAll(baseDir, 0o755); err != nil { + t.Fatalf("failed to create base test dir: %v", err) + } + + dir, err := os.MkdirTemp(baseDir, t.Name()+"-*") + if err != nil { + t.Fatalf("failed to create test wal dir: %v", err) + } + + t.Cleanup(func() { + _ = os.RemoveAll(dir) + }) + + return dir +} + +func writeTestSegment(t *testing.T, dir, name string, content []byte) { + t.Helper() + + if err := os.WriteFile(filepath.Join(dir, name), content, 0o644); err != nil { + t.Fatalf("failed to write test segment %s: %v", name, err) + } +} + +func createTestConfig(walDir, serverURL string) *config.Config { + isDeleteEnabled := true + + return &config.Config{ + DatabasusHost: serverURL, + DbID: "test-db-id", + Token: "test-token", + PgWalDir: walDir, + IsDeleteWalAfterUpload: &isDeleteEnabled, + } +} + +func decompressZstd(t *testing.T, data []byte) []byte { + t.Helper() + + decoder, err := zstd.NewReader(nil) + require.NoError(t, err) + defer decoder.Close() + + decoded, err := decoder.DecodeAll(data, nil) + require.NoError(t, err) + + return decoded +} diff --git a/agent/internal/logger/logger.go b/agent/internal/logger/logger.go index bab697f..f57c3dc 100644 --- a/agent/internal/logger/logger.go +++ b/agent/internal/logger/logger.go @@ -1,45 +1,119 @@ package logger import ( + "fmt" + "io" "log/slog" "os" "sync" "time" ) +const ( + logFileName = "databasus.log" + oldLogFileName = "databasus.log.old" + maxLogFileSize = 5 * 1024 * 1024 // 5MB +) + +type rotatingWriter struct { + mu sync.Mutex + file *os.File + currentSize int64 + maxSize int64 + logPath string + oldLogPath string +} + +func (w *rotatingWriter) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + + if w.currentSize+int64(len(p)) > w.maxSize { + if err := w.rotate(); err != nil { + return 0, fmt.Errorf("failed to rotate log file: %w", err) + } + } + + n, err := w.file.Write(p) + w.currentSize += int64(n) + + return n, err +} + +func (w *rotatingWriter) rotate() error { + if err := w.file.Close(); err != nil { + return fmt.Errorf("failed to close %s: %w", w.logPath, err) + } + + if err := os.Remove(w.oldLogPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove %s: %w", w.oldLogPath, err) + } + + if err := os.Rename(w.logPath, w.oldLogPath); err != nil { + return fmt.Errorf("failed to rename %s to %s: %w", w.logPath, w.oldLogPath, err) + } + + f, err := os.OpenFile(w.logPath, os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return fmt.Errorf("failed to create new %s: %w", w.logPath, err) + } + + w.file = f + w.currentSize = 0 + + return nil +} + var ( loggerInstance *slog.Logger once sync.Once ) -func Init(isDebug bool) { - level := slog.LevelInfo - if isDebug { - level = slog.LevelDebug - } - - once.Do(func() { - loggerInstance = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - Level: level, - ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { - if a.Key == slog.TimeKey { - a.Value = slog.StringValue(time.Now().Format("2006/01/02 15:04:05")) - } - if a.Key == slog.LevelKey { - return slog.Attr{} - } - - return a - }, - })) - }) -} - -// GetLogger returns a singleton slog.Logger that logs to the console func GetLogger() *slog.Logger { - if loggerInstance == nil { - Init(false) - } + once.Do(func() { + initialize() + }) return loggerInstance } + +func initialize() { + writer := buildWriter() + + loggerInstance = slog.New(slog.NewTextHandler(writer, &slog.HandlerOptions{ + Level: slog.LevelInfo, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + if a.Key == slog.TimeKey { + a.Value = slog.StringValue(time.Now().Format("2006/01/02 15:04:05")) + } + if a.Key == slog.LevelKey { + return slog.Attr{} + } + + return a + }, + })) +} + +func buildWriter() io.Writer { + f, err := os.OpenFile(logFileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to open %s for logging: %v\n", logFileName, err) + return os.Stdout + } + + var currentSize int64 + if info, err := f.Stat(); err == nil { + currentSize = info.Size() + } + + rw := &rotatingWriter{ + file: f, + currentSize: currentSize, + maxSize: maxLogFileSize, + logPath: logFileName, + oldLogPath: oldLogFileName, + } + + return io.MultiWriter(os.Stdout, rw) +} diff --git a/agent/internal/logger/logger_test.go b/agent/internal/logger/logger_test.go new file mode 100644 index 0000000..022733e --- /dev/null +++ b/agent/internal/logger/logger_test.go @@ -0,0 +1,128 @@ +package logger + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Write_DataWrittenToFile(t *testing.T) { + rw, logPath, _ := setupRotatingWriter(t, 1024) + + data := []byte("hello world\n") + n, err := rw.Write(data) + + require.NoError(t, err) + assert.Equal(t, len(data), n) + assert.Equal(t, int64(len(data)), rw.currentSize) + + content, err := os.ReadFile(logPath) + require.NoError(t, err) + assert.Equal(t, string(data), string(content)) +} + +func Test_Write_WhenLimitExceeded_FileRotated(t *testing.T) { + rw, logPath, oldLogPath := setupRotatingWriter(t, 100) + + firstData := []byte(strings.Repeat("A", 80)) + _, err := rw.Write(firstData) + require.NoError(t, err) + + secondData := []byte(strings.Repeat("B", 30)) + _, err = rw.Write(secondData) + require.NoError(t, err) + + oldContent, err := os.ReadFile(oldLogPath) + require.NoError(t, err) + assert.Equal(t, string(firstData), string(oldContent)) + + newContent, err := os.ReadFile(logPath) + require.NoError(t, err) + assert.Equal(t, string(secondData), string(newContent)) + + assert.Equal(t, int64(len(secondData)), rw.currentSize) +} + +func Test_Write_WhenOldFileExists_OldFileReplaced(t *testing.T) { + rw, _, oldLogPath := setupRotatingWriter(t, 100) + + require.NoError(t, os.WriteFile(oldLogPath, []byte("stale data"), 0o644)) + + _, err := rw.Write([]byte(strings.Repeat("A", 80))) + require.NoError(t, err) + + _, err = rw.Write([]byte(strings.Repeat("B", 30))) + require.NoError(t, err) + + oldContent, err := os.ReadFile(oldLogPath) + require.NoError(t, err) + assert.Equal(t, strings.Repeat("A", 80), string(oldContent)) +} + +func Test_Write_MultipleSmallWrites_CurrentSizeAccumulated(t *testing.T) { + rw, _, _ := setupRotatingWriter(t, 1024) + + var totalWritten int64 + for i := 0; i < 10; i++ { + data := []byte("line\n") + n, err := rw.Write(data) + require.NoError(t, err) + + totalWritten += int64(n) + } + + assert.Equal(t, totalWritten, rw.currentSize) + assert.Equal(t, int64(50), rw.currentSize) +} + +func Test_Write_ExactlyAtBoundary_NoRotationUntilNextByte(t *testing.T) { + rw, logPath, oldLogPath := setupRotatingWriter(t, 100) + + exactData := []byte(strings.Repeat("X", 100)) + _, err := rw.Write(exactData) + require.NoError(t, err) + + _, err = os.Stat(oldLogPath) + assert.True(t, os.IsNotExist(err), ".old file should not exist yet") + + content, err := os.ReadFile(logPath) + require.NoError(t, err) + assert.Equal(t, string(exactData), string(content)) + + _, err = rw.Write([]byte("Z")) + require.NoError(t, err) + + _, err = os.Stat(oldLogPath) + assert.NoError(t, err, ".old file should exist after exceeding limit") + + assert.Equal(t, int64(1), rw.currentSize) +} + +func setupRotatingWriter(t *testing.T, maxSize int64) (*rotatingWriter, string, string) { + t.Helper() + + dir := t.TempDir() + logPath := filepath.Join(dir, "test.log") + oldLogPath := filepath.Join(dir, "test.log.old") + + f, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY, 0o644) + require.NoError(t, err) + + rw := &rotatingWriter{ + file: f, + currentSize: 0, + maxSize: maxSize, + logPath: logPath, + oldLogPath: oldLogPath, + } + + t.Cleanup(func() { + rw.file.Close() + }) + + return rw, logPath, oldLogPath +} diff --git a/backend/internal/features/backups/backups/backuping/cleaner.go b/backend/internal/features/backups/backups/backuping/cleaner.go index 44c3c1b..17b6202 100644 --- a/backend/internal/features/backups/backups/backuping/cleaner.go +++ b/backend/internal/features/backups/backups/backuping/cleaner.go @@ -59,6 +59,10 @@ func (c *BackupCleaner) Run(ctx context.Context) { if err := c.cleanExceededBackups(); err != nil { c.logger.Error("Failed to clean exceeded backups", "error", err) } + + if err := c.cleanStaleUploadedBasebackups(); err != nil { + c.logger.Error("Failed to clean stale uploaded basebackups", "error", err) + } } } }) @@ -100,6 +104,67 @@ func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemo c.backupRemoveListeners = append(c.backupRemoveListeners, listener) } +func (c *BackupCleaner) cleanStaleUploadedBasebackups() error { + staleBackups, err := c.backupRepository.FindStaleUploadedBasebackups( + time.Now().UTC().Add(-10 * time.Minute), + ) + if err != nil { + return fmt.Errorf("failed to find stale uploaded basebackups: %w", err) + } + + for _, backup := range staleBackups { + staleStorage, storageErr := c.storageService.GetStorageByID(backup.StorageID) + if storageErr != nil { + c.logger.Error( + "Failed to get storage for stale basebackup cleanup", + "backupId", backup.ID, + "storageId", backup.StorageID, + "error", storageErr, + ) + } else { + if err := staleStorage.DeleteFile(c.fieldEncryptor, backup.FileName); err != nil { + c.logger.Error( + "Failed to delete stale basebackup file", + "backupId", backup.ID, + "fileName", backup.FileName, + "error", err, + ) + } + + metadataFileName := backup.FileName + ".metadata" + if err := staleStorage.DeleteFile(c.fieldEncryptor, metadataFileName); err != nil { + c.logger.Error( + "Failed to delete stale basebackup metadata file", + "backupId", backup.ID, + "fileName", metadataFileName, + "error", err, + ) + } + } + + failMsg := "basebackup finalization timed out after 10 minutes" + backup.Status = backups_core.BackupStatusFailed + backup.FailMessage = &failMsg + + if err := c.backupRepository.Save(backup); err != nil { + c.logger.Error( + "Failed to mark stale uploaded basebackup as failed", + "backupId", backup.ID, + "error", err, + ) + continue + } + + c.logger.Info( + "Marked stale uploaded basebackup as failed and cleaned storage", + "backupId", backup.ID, + "databaseId", backup.DatabaseID, + ) + } + + return nil +} + func (c *BackupCleaner) cleanByRetentionPolicy() error { enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups() if err != nil { diff --git a/backend/internal/features/backups/backups/backuping/cleaner_test.go b/backend/internal/features/backups/backups/backuping/cleaner_test.go index c8995dd..02b800e 100644 --- a/backend/internal/features/backups/backups/backuping/cleaner_test.go +++ b/backend/internal/features/backups/backups/backuping/cleaner_test.go @@ -1004,6 +1004,191 @@ func (m *mockBackupRemoveListener) OnBeforeBackupRemove(backup *backups_core.Bac return nil } +func Test_CleanStaleUploadedBasebackups_MarksAsFailed(t *testing.T) { + router := CreateTestRouter() + owner := users_testing.CreateTestUser(users_enums.UserRoleMember) + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router) + storage := storages.CreateTestStorage(workspace.ID) + notifier := notifiers.CreateTestNotifier(workspace.ID) + database := databases.CreateTestDatabase(workspace.ID, storage, notifier) + + defer func() { + backups, _ := backupRepository.FindByDatabaseID(database.ID) + for _, backup := range backups { + backupRepository.DeleteByID(backup.ID) + } + + databases.RemoveTestDatabase(database) + time.Sleep(50 * time.Millisecond) + notifiers.RemoveTestNotifier(notifier) + storages.RemoveTestStorage(storage.ID) + workspaces_testing.RemoveTestWorkspace(workspace, router) + }() + + staleTime := time.Now().UTC().Add(-15 * time.Minute) + walBackupType := backups_core.PgWalBackupTypeFullBackup + staleBackup := &backups_core.Backup{ + ID: uuid.New(), + DatabaseID: database.ID, + StorageID: storage.ID, + Status: backups_core.BackupStatusInProgress, + PgWalBackupType: &walBackupType, + UploadCompletedAt: &staleTime, + CreatedAt: staleTime, + } + + err := backupRepository.Save(staleBackup) + assert.NoError(t, err) + + cleaner := GetBackupCleaner() + err = cleaner.cleanStaleUploadedBasebackups() + assert.NoError(t, err) + + updated, err := backupRepository.FindByID(staleBackup.ID) + assert.NoError(t, err) + assert.Equal(t, backups_core.BackupStatusFailed, updated.Status) + assert.NotNil(t, updated.FailMessage) + assert.Contains(t, *updated.FailMessage, "finalization timed out") +} + +func Test_CleanStaleUploadedBasebackups_SkipsRecentUploads(t *testing.T) { + router := CreateTestRouter() + owner := users_testing.CreateTestUser(users_enums.UserRoleMember) + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router) + storage := storages.CreateTestStorage(workspace.ID) + notifier := notifiers.CreateTestNotifier(workspace.ID) + database := databases.CreateTestDatabase(workspace.ID, storage, notifier) + + defer func() { + backups, _ := backupRepository.FindByDatabaseID(database.ID) + for _, backup := range backups { + backupRepository.DeleteByID(backup.ID) + } + + databases.RemoveTestDatabase(database) + time.Sleep(50 * time.Millisecond) + notifiers.RemoveTestNotifier(notifier) + storages.RemoveTestStorage(storage.ID) + workspaces_testing.RemoveTestWorkspace(workspace, router) + }() + + recentTime := time.Now().UTC().Add(-2 * time.Minute) + walBackupType := backups_core.PgWalBackupTypeFullBackup + recentBackup := &backups_core.Backup{ + ID: uuid.New(), + DatabaseID: database.ID, + StorageID: storage.ID, + Status: backups_core.BackupStatusInProgress, + PgWalBackupType: &walBackupType, + UploadCompletedAt: &recentTime, + CreatedAt: recentTime, + } + + err := backupRepository.Save(recentBackup) + assert.NoError(t, err) + + cleaner := GetBackupCleaner() + err = cleaner.cleanStaleUploadedBasebackups() + assert.NoError(t, err) + + updated, err := backupRepository.FindByID(recentBackup.ID) + assert.NoError(t, err) + assert.Equal(t, backups_core.BackupStatusInProgress, updated.Status) +} + +func Test_CleanStaleUploadedBasebackups_SkipsActiveStreaming(t *testing.T) { + router := CreateTestRouter() + owner := users_testing.CreateTestUser(users_enums.UserRoleMember) + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router) + storage := storages.CreateTestStorage(workspace.ID) + notifier := notifiers.CreateTestNotifier(workspace.ID) + database := databases.CreateTestDatabase(workspace.ID, storage, notifier) + + defer func() { + backups, _ := backupRepository.FindByDatabaseID(database.ID) + for _, backup := range backups { + backupRepository.DeleteByID(backup.ID) + } + + databases.RemoveTestDatabase(database) + time.Sleep(50 * time.Millisecond) + notifiers.RemoveTestNotifier(notifier) + storages.RemoveTestStorage(storage.ID) + workspaces_testing.RemoveTestWorkspace(workspace, router) + }() + + walBackupType := backups_core.PgWalBackupTypeFullBackup + activeBackup := &backups_core.Backup{ + ID: uuid.New(), + DatabaseID: database.ID, + StorageID: storage.ID, + Status: backups_core.BackupStatusInProgress, + PgWalBackupType: &walBackupType, + CreatedAt: time.Now().UTC().Add(-30 * time.Minute), + } + + err := backupRepository.Save(activeBackup) + assert.NoError(t, err) + + cleaner := GetBackupCleaner() + err = cleaner.cleanStaleUploadedBasebackups() + assert.NoError(t, err) + + updated, err := backupRepository.FindByID(activeBackup.ID) + assert.NoError(t, err) + assert.Equal(t, backups_core.BackupStatusInProgress, updated.Status) + assert.Nil(t, updated.UploadCompletedAt) +} + +func Test_CleanStaleUploadedBasebackups_CleansStorageFiles(t *testing.T) { + router := CreateTestRouter() + owner := users_testing.CreateTestUser(users_enums.UserRoleMember) + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router) + storage := storages.CreateTestStorage(workspace.ID) + notifier := notifiers.CreateTestNotifier(workspace.ID) + database := databases.CreateTestDatabase(workspace.ID, storage, notifier) + + defer func() { + backups, _ := backupRepository.FindByDatabaseID(database.ID) + for _, backup := range backups { + backupRepository.DeleteByID(backup.ID) + } + + databases.RemoveTestDatabase(database) + time.Sleep(50 * time.Millisecond) + notifiers.RemoveTestNotifier(notifier) + storages.RemoveTestStorage(storage.ID) + workspaces_testing.RemoveTestWorkspace(workspace, router) + }() + + staleTime := time.Now().UTC().Add(-15 * time.Minute) + walBackupType := backups_core.PgWalBackupTypeFullBackup + staleBackup := &backups_core.Backup{ + ID: uuid.New(), + DatabaseID: database.ID, + StorageID: storage.ID, + Status: backups_core.BackupStatusInProgress, + PgWalBackupType: &walBackupType, + UploadCompletedAt: &staleTime, + BackupSizeMb: 500, + FileName: "stale-basebackup-test-file", + CreatedAt: staleTime, + } + + err := backupRepository.Save(staleBackup) + assert.NoError(t, err) + + cleaner := GetBackupCleaner() + err = cleaner.cleanStaleUploadedBasebackups() + assert.NoError(t, err) + + updated, err := backupRepository.FindByID(staleBackup.ID) + assert.NoError(t, err) + assert.Equal(t, backups_core.BackupStatusFailed, updated.Status) + assert.NotNil(t, updated.FailMessage) + assert.Contains(t, *updated.FailMessage, "finalization timed out") +} + func createTestInterval() *intervals.Interval { timeOfDay := "04:00" interval := &intervals.Interval{ diff --git a/backend/internal/features/backups/backups/controllers/postgres_wal_controller.go b/backend/internal/features/backups/backups/controllers/postgres_wal_controller.go index d14706d..5888651 100644 --- a/backend/internal/features/backups/backups/controllers/postgres_wal_controller.go +++ b/backend/internal/features/backups/backups/controllers/postgres_wal_controller.go @@ -3,12 +3,10 @@ package backups_controllers import ( "io" "net/http" - "strconv" "github.com/gin-gonic/gin" "github.com/google/uuid" - backups_core "databasus-backend/internal/features/backups/backups/core" backups_dto "databasus-backend/internal/features/backups/backups/dto" backups_services "databasus-backend/internal/features/backups/backups/services" "databasus-backend/internal/features/databases" @@ -25,8 +23,11 @@ func (c *PostgreWalBackupController) RegisterRoutes(router *gin.RouterGroup) { walRoutes := router.Group("/backups/postgres/wal") walRoutes.GET("/next-full-backup-time", c.GetNextFullBackupTime) + walRoutes.GET("/is-wal-chain-valid-since-last-full-backup", c.IsWalChainValidSinceLastBackup) walRoutes.POST("/error", c.ReportError) - walRoutes.POST("/upload", c.Upload) + walRoutes.POST("/upload/wal", c.UploadWalSegment) + walRoutes.POST("/upload/full-start", c.StartFullBackupUpload) + walRoutes.POST("/upload/full-complete", c.CompleteFullBackupUpload) walRoutes.GET("/restore/plan", c.GetRestorePlan) walRoutes.GET("/restore/download", c.DownloadBackupFile) } @@ -90,91 +91,66 @@ func (c *PostgreWalBackupController) ReportError(ctx *gin.Context) { ctx.Status(http.StatusOK) } -// Upload -// @Summary Stream upload a basebackup or WAL segment -// @Description Accepts a zstd-compressed binary stream and stores it in the database's configured storage. -// The server generates the storage filename; agents do not control the destination path. -// For WAL segment uploads the server validates the WAL chain and returns 409 if a gap is detected -// or 400 if no full backup exists yet (agent should trigger a full basebackup in both cases). +// IsWalChainValidSinceLastBackup +// @Summary Check WAL chain validity since last full backup +// @Description Checks whether the WAL chain is continuous since the last completed full backup. +// Returns isValid=true if the chain is intact, or isValid=false with error details if not. // @Tags backups-wal -// @Accept application/octet-stream // @Produce json // @Security AgentToken -// @Param X-Upload-Type header string true "Upload type" Enums(basebackup, wal) -// @Param X-Wal-Segment-Name header string false "24-hex WAL segment identifier (required for wal uploads, e.g. 0000000100000001000000AB)" -// @Param X-Wal-Segment-Size header int false "WAL segment size in bytes reported by the PostgreSQL instance (default: 16777216)" -// @Param fullBackupWalStartSegment query string false "First WAL segment needed to make the basebackup consistent (required for basebackup uploads)" -// @Param fullBackupWalStopSegment query string false "Last WAL segment included in the basebackup (required for basebackup uploads)" -// @Success 204 -// @Failure 400 {object} backups_dto.UploadGapResponse "No full backup exists (error: no_full_backup)" +// @Success 200 {object} backups_dto.IsWalChainValidResponse // @Failure 401 {object} map[string]string -// @Failure 409 {object} backups_dto.UploadGapResponse "WAL chain gap detected (error: gap_detected)" // @Failure 500 {object} map[string]string -// @Router /backups/postgres/wal/upload [post] -func (c *PostgreWalBackupController) Upload(ctx *gin.Context) { +// @Router /backups/postgres/wal/is-wal-chain-valid-since-last-full-backup [get] +func (c *PostgreWalBackupController) IsWalChainValidSinceLastBackup(ctx *gin.Context) { database, err := c.getDatabase(ctx) if err != nil { ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"}) return } - uploadType := backups_core.PgWalUploadType(ctx.GetHeader("X-Upload-Type")) - if uploadType != backups_core.PgWalUploadTypeBasebackup && - uploadType != backups_core.PgWalUploadTypeWal { + response, err := c.walService.IsWalChainValid(database) + if err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + ctx.JSON(http.StatusOK, response) +} + +// UploadWalSegment +// @Summary Stream upload a WAL segment +// @Description Accepts a zstd-compressed WAL segment binary stream and stores it in the database's configured storage. +// WAL segments are accepted unconditionally. +// @Tags backups-wal +// @Accept application/octet-stream +// @Security AgentToken +// @Param X-Wal-Segment-Name header string true "24-hex WAL segment identifier (e.g. 0000000100000001000000AB)" +// @Success 204 +// @Failure 400 {object} map[string]string +// @Failure 401 {object} map[string]string +// @Failure 500 {object} map[string]string +// @Router /backups/postgres/wal/upload/wal [post] +func (c *PostgreWalBackupController) UploadWalSegment(ctx *gin.Context) { + database, err := c.getDatabase(ctx) + if err != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"}) + return + } + + walSegmentName := ctx.GetHeader("X-Wal-Segment-Name") + if walSegmentName == "" { ctx.JSON( http.StatusBadRequest, - gin.H{"error": "X-Upload-Type must be 'basebackup' or 'wal'"}, + gin.H{"error": "X-Wal-Segment-Name is required for wal uploads"}, ) return } - walSegmentName := "" - if uploadType == backups_core.PgWalUploadTypeWal { - walSegmentName = ctx.GetHeader("X-Wal-Segment-Name") - if walSegmentName == "" { - ctx.JSON( - http.StatusBadRequest, - gin.H{"error": "X-Wal-Segment-Name is required for wal uploads"}, - ) - return - } - } - - if uploadType == backups_core.PgWalUploadTypeBasebackup { - if ctx.Query("fullBackupWalStartSegment") == "" || - ctx.Query("fullBackupWalStopSegment") == "" { - ctx.JSON( - http.StatusBadRequest, - gin.H{ - "error": "fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads", - }, - ) - return - } - } - - walSegmentSizeBytes := int64(0) - if raw := ctx.GetHeader("X-Wal-Segment-Size"); raw != "" { - parsed, parseErr := strconv.ParseInt(raw, 10, 64) - if parseErr != nil || parsed <= 0 { - ctx.JSON( - http.StatusBadRequest, - gin.H{"error": "X-Wal-Segment-Size must be a positive integer"}, - ) - return - } - - walSegmentSizeBytes = parsed - } - - gapResp, uploadErr := c.walService.UploadWal( + uploadErr := c.walService.UploadWalSegment( ctx.Request.Context(), database, - uploadType, walSegmentName, - ctx.Query("fullBackupWalStartSegment"), - ctx.Query("fullBackupWalStopSegment"), - walSegmentSizeBytes, ctx.Request.Body, ) @@ -183,17 +159,81 @@ func (c *PostgreWalBackupController) Upload(ctx *gin.Context) { return } - if gapResp != nil { - if gapResp.Error == "no_full_backup" { - ctx.JSON(http.StatusBadRequest, gapResp) - return - } + ctx.Status(http.StatusNoContent) +} - ctx.JSON(http.StatusConflict, gapResp) +// StartFullBackupUpload +// @Summary Stream upload a full basebackup (Phase 1) +// @Description Accepts a zstd-compressed basebackup binary stream and stores it in the database's configured storage. +// Returns a backupId that must be completed via /upload/full-complete with WAL segment names. +// @Tags backups-wal +// @Accept application/octet-stream +// @Produce json +// @Security AgentToken +// @Success 200 {object} backups_dto.UploadBasebackupResponse +// @Failure 401 {object} map[string]string +// @Failure 500 {object} map[string]string +// @Router /backups/postgres/wal/upload/full-start [post] +func (c *PostgreWalBackupController) StartFullBackupUpload(ctx *gin.Context) { + database, err := c.getDatabase(ctx) + if err != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"}) return } - ctx.Status(http.StatusNoContent) + backupID, uploadErr := c.walService.UploadBasebackup( + ctx.Request.Context(), + database, + ctx.Request.Body, + ) + + if uploadErr != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{"error": uploadErr.Error()}) + return + } + + ctx.JSON(http.StatusOK, backups_dto.UploadBasebackupResponse{ + BackupID: backupID, + }) +} + +// CompleteFullBackupUpload +// @Summary Complete a previously uploaded basebackup (Phase 2) +// @Description Sets WAL segment names and marks the basebackup as completed, or marks it as failed if an error is provided. +// @Tags backups-wal +// @Accept json +// @Security AgentToken +// @Param request body backups_dto.FinalizeBasebackupRequest true "Completion details" +// @Success 200 +// @Failure 400 {object} map[string]string +// @Failure 401 {object} map[string]string +// @Failure 500 {object} map[string]string +// @Router /backups/postgres/wal/upload/full-complete [post] +func (c *PostgreWalBackupController) CompleteFullBackupUpload(ctx *gin.Context) { + database, err := c.getDatabase(ctx) + if err != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"}) + return + } + + var request backups_dto.FinalizeBasebackupRequest + if err := ctx.ShouldBindJSON(&request); err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := c.walService.FinalizeBasebackup( + database, + request.BackupID, + request.StartSegment, + request.StopSegment, + request.Error, + ); err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + ctx.Status(http.StatusOK) } // GetRestorePlan diff --git a/backend/internal/features/backups/backups/controllers/postgres_wal_controller_test.go b/backend/internal/features/backups/backups/controllers/postgres_wal_controller_test.go index 15764c9..2bafb90 100644 --- a/backend/internal/features/backups/backups/controllers/postgres_wal_controller_test.go +++ b/backend/internal/features/backups/backups/controllers/postgres_wal_controller_test.go @@ -38,7 +38,7 @@ func Test_WalUpload_InProgressStatusSetBeforeStream(t *testing.T) { uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010") pr, pw := io.Pipe() - req := newWalUploadRequest(pr, agentToken, "wal", "000000010000000100000011", "", "") + req := newWalSegmentUploadRequest(pr, agentToken, "000000010000000100000011") w := httptest.NewRecorder() done := make(chan struct{}) @@ -67,7 +67,7 @@ func Test_WalUpload_CompletedStatusAfterSuccessfulStream(t *testing.T) { uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010") body := bytes.NewReader([]byte("wal segment content")) - req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000011", "", "") + req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000011") w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -99,7 +99,7 @@ func Test_WalUpload_FailedStatusWithErrorOnStreamError(t *testing.T) { uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010") pr, pw := io.Pipe() - req := newWalUploadRequest(pr, agentToken, "wal", "000000010000000100000011", "", "") + req := newWalSegmentUploadRequest(pr, agentToken, "000000010000000100000011") w := httptest.NewRecorder() done := make(chan struct{}) @@ -129,59 +129,171 @@ func Test_WalUpload_FailedStatusWithErrorOnStreamError(t *testing.T) { assert.NotNil(t, walBackup.FailMessage) } -func Test_WalUpload_Basebackup_MissingWalSegments_Returns400(t *testing.T) { +func Test_WalUpload_Basebackup_StreamingUpload_Returns200WithBackupId(t *testing.T) { router, db, storage, agentToken, _ := createWalTestSetup(t) defer removeWalTestSetup(db, storage) body := bytes.NewReader([]byte("basebackup content")) - req := newWalUploadRequest(body, agentToken, backups_core.PgWalUploadTypeBasebackup, "", "", "") + req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/full-start", body) + req.Header.Set("Authorization", agentToken) + req.Header.Set("Content-Type", "application/octet-stream") w := httptest.NewRecorder() router.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + + var response backups_dto.UploadBasebackupResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) + assert.NotEqual(t, uuid.Nil, response.BackupID) + + backup, err := backups_core.GetBackupRepository().FindByID(response.BackupID) + require.NoError(t, err) + assert.Equal(t, backups_core.BackupStatusInProgress, backup.Status) + assert.NotNil(t, backup.UploadCompletedAt) +} + +func Test_FinalizeBasebackup_ValidSegments_MarksCompleted(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + backupID := uploadBasebackupPhase1(t, router, agentToken) + + completeFullBackupUpload(t, router, agentToken, backupID, + "000000010000000100000001", "000000010000000100000010", nil) + + backup, err := backups_core.GetBackupRepository().FindByID(backupID) + require.NoError(t, err) + assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status) + require.NotNil(t, backup.PgFullBackupWalStartSegmentName) + assert.Equal(t, "000000010000000100000001", *backup.PgFullBackupWalStartSegmentName) + require.NotNil(t, backup.PgFullBackupWalStopSegmentName) + assert.Equal(t, "000000010000000100000010", *backup.PgFullBackupWalStopSegmentName) +} + +func Test_FinalizeBasebackup_WithError_MarksFailed(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + backupID := uploadBasebackupPhase1(t, router, agentToken) + + errMsg := "pg_basebackup stderr parse failed" + completeFullBackupUpload(t, router, agentToken, backupID, "", "", &errMsg) + + backup, err := backups_core.GetBackupRepository().FindByID(backupID) + require.NoError(t, err) + assert.Equal(t, backups_core.BackupStatusFailed, backup.Status) + require.NotNil(t, backup.FailMessage) + assert.Equal(t, errMsg, *backup.FailMessage) +} + +func Test_FinalizeBasebackup_InvalidBackupId_Returns400(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + nonExistentID := uuid.New() + body, _ := json.Marshal(backups_dto.FinalizeBasebackupRequest{ + BackupID: nonExistentID, + StartSegment: "000000010000000100000001", + StopSegment: "000000010000000100000010", + }) + + req, _ := http.NewRequest( + http.MethodPost, + "/api/v1/backups/postgres/wal/upload/full-complete", + bytes.NewReader(body), + ) + req.Header.Set("Authorization", agentToken) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) } -func Test_WalUpload_WalSegment_NoFullBackup_Returns400(t *testing.T) { +func Test_FinalizeBasebackup_AlreadyCompleted_Returns400(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + backupID := uploadBasebackupPhase1(t, router, agentToken) + + completeFullBackupUpload(t, router, agentToken, backupID, + "000000010000000100000001", "000000010000000100000010", nil) + + // Second finalize should fail. + body, _ := json.Marshal(backups_dto.FinalizeBasebackupRequest{ + BackupID: backupID, + StartSegment: "000000010000000100000001", + StopSegment: "000000010000000100000010", + }) + + req, _ := http.NewRequest( + http.MethodPost, + "/api/v1/backups/postgres/wal/upload/full-complete", + bytes.NewReader(body), + ) + req.Header.Set("Authorization", agentToken) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +func Test_FinalizeBasebackup_InvalidToken_Returns401(t *testing.T) { + router, db, storage, _, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + body, _ := json.Marshal(backups_dto.FinalizeBasebackupRequest{ + BackupID: uuid.New(), + StartSegment: "000000010000000100000001", + StopSegment: "000000010000000100000010", + }) + + req, _ := http.NewRequest( + http.MethodPost, + "/api/v1/backups/postgres/wal/upload/full-complete", + bytes.NewReader(body), + ) + req.Header.Set("Authorization", "invalid-token") + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) +} + +func Test_WalUpload_WalSegment_WithoutFullBackup_Returns204(t *testing.T) { router, db, storage, agentToken, _ := createWalTestSetup(t) defer removeWalTestSetup(db, storage) - // No full backup inserted — chain anchor is missing. body := bytes.NewReader([]byte("wal content")) - req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000001", "", "") + req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000001") w := httptest.NewRecorder() router.ServeHTTP(w, req) - assert.Equal(t, http.StatusBadRequest, w.Code) - - var resp backups_dto.UploadGapResponse - require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) - assert.Equal(t, "no_full_backup", resp.Error) + assert.Equal(t, http.StatusNoContent, w.Code) } -func Test_WalUpload_WalSegment_GapDetected_Returns409WithExpectedAndReceived(t *testing.T) { +func Test_WalUpload_WalSegment_WithGap_Returns204(t *testing.T) { router, db, storage, agentToken, _ := createWalTestSetup(t) defer removeWalTestSetup(db, storage) - // Full backup stops at ...0010; upload one WAL segment at ...0011. uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010") uploadWalSegment(t, router, agentToken, "000000010000000100000011") - // Send ...0013 — should be rejected because ...0012 is missing. + // Skip ...0012, upload ...0013 — should succeed (no chain validation on upload). body := bytes.NewReader([]byte("wal content")) - req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000013", "", "") + req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000013") w := httptest.NewRecorder() router.ServeHTTP(w, req) - assert.Equal(t, http.StatusConflict, w.Code) - - var resp backups_dto.UploadGapResponse - require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) - assert.Equal(t, "gap_detected", resp.Error) - assert.Equal(t, "000000010000000100000012", resp.ExpectedSegmentName) - assert.Equal(t, "000000010000000100000013", resp.ReceivedSegmentName) + assert.Equal(t, http.StatusNoContent, w.Code) } func Test_WalUpload_WalSegment_DuplicateSegment_Returns200Idempotent(t *testing.T) { @@ -192,14 +304,14 @@ func Test_WalUpload_WalSegment_DuplicateSegment_Returns200Idempotent(t *testing. // Upload ...0011 once. body1 := bytes.NewReader([]byte("wal content")) - req1 := newWalUploadRequest(body1, agentToken, "wal", "000000010000000100000011", "", "") + req1 := newWalSegmentUploadRequest(body1, agentToken, "000000010000000100000011") w1 := httptest.NewRecorder() router.ServeHTTP(w1, req1) require.Equal(t, http.StatusNoContent, w1.Code) // Upload the same segment again — must return 204 (idempotent). body2 := bytes.NewReader([]byte("wal content")) - req2 := newWalUploadRequest(body2, agentToken, "wal", "000000010000000100000011", "", "") + req2 := newWalSegmentUploadRequest(body2, agentToken, "000000010000000100000011") w2 := httptest.NewRecorder() router.ServeHTTP(w2, req2) @@ -228,7 +340,7 @@ func Test_WalUpload_WalSegment_ValidNextSegment_Returns200AndCreatesRecord(t *te // First WAL segment after the full backup stop segment. body := bytes.NewReader([]byte("wal segment data")) - req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000011", "", "") + req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000011") w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -255,6 +367,108 @@ func Test_WalUpload_WalSegment_ValidNextSegment_Returns200AndCreatesRecord(t *te assert.Equal(t, "000000010000000100000011", *walBackup.PgWalSegmentName) } +func Test_IsWalChainValid_NoFullBackup_ReturnsFalse(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + var response backups_dto.IsWalChainValidResponse + test_utils.MakeGetRequestAndUnmarshal( + t, router, + "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup", + agentToken, + http.StatusOK, + &response, + ) + + assert.False(t, response.IsValid) + assert.Equal(t, "no_full_backup", response.Error) +} + +func Test_IsWalChainValid_FullBackupOnly_ReturnsTrue(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010") + + var response backups_dto.IsWalChainValidResponse + test_utils.MakeGetRequestAndUnmarshal( + t, router, + "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup", + agentToken, + http.StatusOK, + &response, + ) + + assert.True(t, response.IsValid) + assert.Empty(t, response.Error) +} + +func Test_IsWalChainValid_ContinuousChain_ReturnsTrue(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010") + uploadWalSegment(t, router, agentToken, "000000010000000100000011") + uploadWalSegment(t, router, agentToken, "000000010000000100000012") + uploadWalSegment(t, router, agentToken, "000000010000000100000013") + + var response backups_dto.IsWalChainValidResponse + test_utils.MakeGetRequestAndUnmarshal( + t, router, + "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup", + agentToken, + http.StatusOK, + &response, + ) + + assert.True(t, response.IsValid) +} + +func Test_IsWalChainValid_BrokenChain_ReturnsFalse(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010") + uploadWalSegment(t, router, agentToken, "000000010000000100000011") + uploadWalSegment(t, router, agentToken, "000000010000000100000012") + uploadWalSegment(t, router, agentToken, "000000010000000100000013") + + // Delete the middle segment to create a gap. + middleSeg, err := backups_core.GetBackupRepository().FindWalSegmentByName( + db.ID, "000000010000000100000012", + ) + require.NoError(t, err) + require.NotNil(t, middleSeg) + require.NoError(t, backups_core.GetBackupRepository().DeleteByID(middleSeg.ID)) + + var response backups_dto.IsWalChainValidResponse + test_utils.MakeGetRequestAndUnmarshal( + t, router, + "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup", + agentToken, + http.StatusOK, + &response, + ) + + assert.False(t, response.IsValid) + assert.Equal(t, "wal_chain_broken", response.Error) + assert.Equal(t, "000000010000000100000011", response.LastContiguousSegment) +} + +func Test_IsWalChainValid_InvalidToken_Returns401(t *testing.T) { + router, db, storage, _, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + resp := test_utils.MakeGetRequest( + t, router, + "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup", + "invalid-token", + http.StatusUnauthorized, + ) + + assert.Contains(t, string(resp.Body), "invalid agent token") +} + func Test_ReportError_ValidTokenAndError_CreatesFailedBackupRecord(t *testing.T) { router, db, storage, agentToken, _ := createWalTestSetup(t) defer removeWalTestSetup(db, storage) @@ -457,29 +671,14 @@ func Test_GetNextFullBackupTime_WalSegmentAfterFullBackup_DoesNotImpactTime(t *t setHourlyInterval(t, router, db.ID, ownerToken) - // Upload basebackup via API. - bbBody := bytes.NewReader([]byte("basebackup content")) - bbReq := newWalUploadRequest( - bbBody, agentToken, backups_core.PgWalUploadTypeBasebackup, "", - "000000010000000100000001", "000000010000000100000010", - ) - bbW := httptest.NewRecorder() - router.ServeHTTP(bbW, bbReq) - require.Equal(t, http.StatusNoContent, bbW.Code) + uploadBasebackup(t, router, agentToken, + "000000010000000100000001", "000000010000000100000010") // Shift the full backup's CreatedAt to 2 hours ago. twoHoursAgo := time.Now().UTC().Add(-2 * time.Hour) updateLastFullBackupTime(t, db.ID, twoHoursAgo) - // Upload WAL segment via API. - walBody := bytes.NewReader([]byte("wal segment content")) - walReq := newWalUploadRequest( - walBody, agentToken, backups_core.PgWalUploadTypeWal, - "000000010000000100000011", "", "", - ) - walW := httptest.NewRecorder() - router.ServeHTTP(walW, walReq) - require.Equal(t, http.StatusNoContent, walW.Code) + uploadWalSegment(t, router, agentToken, "000000010000000100000011") var response backups_dto.GetNextFullBackupTimeResponse test_utils.MakeGetRequestAndUnmarshal( @@ -508,15 +707,8 @@ func Test_GetNextFullBackupTime_FailedBasebackup_DoesNotImpactTime(t *testing.T) setHourlyInterval(t, router, db.ID, ownerToken) - // Upload a successful basebackup via API. - bbBody := bytes.NewReader([]byte("basebackup content")) - bbReq := newWalUploadRequest( - bbBody, agentToken, backups_core.PgWalUploadTypeBasebackup, "", - "000000010000000100000001", "000000010000000100000010", - ) - bbW := httptest.NewRecorder() - router.ServeHTTP(bbW, bbReq) - require.Equal(t, http.StatusNoContent, bbW.Code) + uploadBasebackup(t, router, agentToken, + "000000010000000100000001", "000000010000000100000010") // Shift the full backup's CreatedAt to 2 hours ago. twoHoursAgo := time.Now().UTC().Add(-2 * time.Hour) @@ -563,15 +755,8 @@ func Test_GetNextFullBackupTime_NewCompletedFullBackup_ImpactsTime(t *testing.T) setHourlyInterval(t, router, db.ID, ownerToken) - // Upload first basebackup via API. - bb1 := bytes.NewReader([]byte("first basebackup")) - bb1Req := newWalUploadRequest( - bb1, agentToken, backups_core.PgWalUploadTypeBasebackup, "", - "000000010000000100000001", "000000010000000100000010", - ) - bb1W := httptest.NewRecorder() - router.ServeHTTP(bb1W, bb1Req) - require.Equal(t, http.StatusNoContent, bb1W.Code) + uploadBasebackup(t, router, agentToken, + "000000010000000100000001", "000000010000000100000010") // Shift the first backup's CreatedAt to 3 hours ago. threeHoursAgo := time.Now().UTC().Add(-3 * time.Hour) @@ -595,15 +780,8 @@ func Test_GetNextFullBackupTime_NewCompletedFullBackup_ImpactsTime(t *testing.T) "first next time should be in the past (old backup)", ) - // Upload second basebackup via API (created now). - bb2 := bytes.NewReader([]byte("second basebackup")) - bb2Req := newWalUploadRequest( - bb2, agentToken, backups_core.PgWalUploadTypeBasebackup, "", - "000000010000000100000011", "000000010000000100000020", - ) - bb2W := httptest.NewRecorder() - router.ServeHTTP(bb2W, bb2Req) - require.Equal(t, http.StatusNoContent, bb2W.Code) + uploadBasebackup(t, router, agentToken, + "000000010000000100000011", "000000010000000100000020") var secondResponse backups_dto.GetNextFullBackupTimeResponse test_utils.MakeGetRequestAndUnmarshal( @@ -841,15 +1019,18 @@ func Test_DownloadRestoreFile_UploadThenDownload_ContentMatches(t *testing.T) { uploadContent := "test-basebackup-content-for-download" body := bytes.NewReader([]byte(uploadContent)) - req := newWalUploadRequest( - body, agentToken, backups_core.PgWalUploadTypeBasebackup, "", - "000000010000000100000001", "000000010000000100000010", - ) + req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/full-start", body) + req.Header.Set("Authorization", agentToken) + req.Header.Set("Content-Type", "application/octet-stream") w := httptest.NewRecorder() router.ServeHTTP(w, req) - require.Equal(t, http.StatusNoContent, w.Code) + require.Equal(t, http.StatusOK, w.Code) - WaitForBackupCompletion(t, db.ID, 0, 5*time.Second) + var uploadResp backups_dto.UploadBasebackupResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &uploadResp)) + + completeFullBackupUpload(t, router, agentToken, uploadResp.BackupID, + "000000010000000100000001", "000000010000000100000010", nil) var planResp backups_dto.GetRestorePlanResponse test_utils.MakeGetRequestAndUnmarshal( @@ -883,7 +1064,7 @@ func Test_DownloadRestoreFile_WalSegment_UploadThenDownload_ContentMatches(t *te walContent := "test-wal-segment-content-for-download" body := bytes.NewReader([]byte(walContent)) - req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000011", "", "") + req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000011") w := httptest.NewRecorder() router.ServeHTTP(w, req) require.Equal(t, http.StatusNoContent, w.Code) @@ -1088,35 +1269,81 @@ func removeWalTestSetup(db *databases.Database, storage *storages.Storage) { storages.RemoveTestStorage(storage.ID) } -func newWalUploadRequest( +func newWalSegmentUploadRequest( body io.Reader, agentToken string, - uploadType backups_core.PgWalUploadType, - walSegmentName string, - walStart string, - walStop string, + segmentName string, ) *http.Request { - url := "/api/v1/backups/postgres/wal/upload" - if walStart != "" || walStop != "" { - url += "?fullBackupWalStartSegment=" + walStart + "&fullBackupWalStopSegment=" + walStop - } - - req, err := http.NewRequest(http.MethodPost, url, body) + req, err := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/wal", body) if err != nil { panic(err) } req.Header.Set("Authorization", agentToken) req.Header.Set("Content-Type", "application/octet-stream") - req.Header.Set("X-Upload-Type", string(uploadType)) - - if walSegmentName != "" { - req.Header.Set("X-Wal-Segment-Name", walSegmentName) - } + req.Header.Set("X-Wal-Segment-Name", segmentName) return req } +func uploadBasebackupPhase1( + t *testing.T, + router *gin.Engine, + agentToken string, +) uuid.UUID { + t.Helper() + + body := bytes.NewReader([]byte("test-basebackup-content")) + + req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/full-start", body) + req.Header.Set("Authorization", agentToken) + req.Header.Set("Content-Type", "application/octet-stream") + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + + var response backups_dto.UploadBasebackupResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) + require.NotEqual(t, uuid.Nil, response.BackupID) + + return response.BackupID +} + +func completeFullBackupUpload( + t *testing.T, + router *gin.Engine, + agentToken string, + backupID uuid.UUID, + walStart string, + walStop string, + errMsg *string, +) { + t.Helper() + + request := backups_dto.FinalizeBasebackupRequest{ + BackupID: backupID, + StartSegment: walStart, + StopSegment: walStop, + Error: errMsg, + } + + reqBody, _ := json.Marshal(request) + req, _ := http.NewRequest( + http.MethodPost, + "/api/v1/backups/postgres/wal/upload/full-complete", + bytes.NewReader(reqBody), + ) + req.Header.Set("Authorization", agentToken) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) +} + func uploadBasebackup( t *testing.T, router *gin.Engine, @@ -1126,15 +1353,8 @@ func uploadBasebackup( ) { t.Helper() - body := bytes.NewReader([]byte("test-basebackup-content")) - req := newWalUploadRequest( - body, agentToken, backups_core.PgWalUploadTypeBasebackup, "", - walStart, walStop, - ) - w := httptest.NewRecorder() - router.ServeHTTP(w, req) - - require.Equal(t, http.StatusNoContent, w.Code) + backupID := uploadBasebackupPhase1(t, router, agentToken) + completeFullBackupUpload(t, router, agentToken, backupID, walStart, walStop, nil) } func uploadWalSegment( @@ -1146,9 +1366,12 @@ func uploadWalSegment( t.Helper() body := bytes.NewReader([]byte("test-wal-segment-content")) - req := newWalUploadRequest( - body, agentToken, backups_core.PgWalUploadTypeWal, segmentName, "", "", - ) + + req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/wal", body) + req.Header.Set("Authorization", agentToken) + req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("X-Wal-Segment-Name", segmentName) + w := httptest.NewRecorder() router.ServeHTTP(w, req) diff --git a/backend/internal/features/backups/backups/core/model.go b/backend/internal/features/backups/backups/core/model.go index 9a4123a..8ca49fd 100644 --- a/backend/internal/features/backups/backups/core/model.go +++ b/backend/internal/features/backups/backups/core/model.go @@ -43,7 +43,8 @@ type Backup struct { PgVersion *string `json:"pgVersion" gorm:"column:pg_version;type:text"` PgWalSegmentName *string `json:"pgWalSegmentName" gorm:"column:pg_wal_segment_name;type:text"` - CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"` + UploadCompletedAt *time.Time `json:"uploadCompletedAt" gorm:"column:upload_completed_at"` + CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"` } func (b *Backup) GenerateFilename(dbName string) { diff --git a/backend/internal/features/backups/backups/core/repository.go b/backend/internal/features/backups/backups/core/repository.go index c422f80..3225601 100644 --- a/backend/internal/features/backups/backups/core/repository.go +++ b/backend/internal/features/backups/backups/core/repository.go @@ -349,6 +349,24 @@ func (r *BackupRepository) FindWalSegmentByName( return &backup, nil } +func (r *BackupRepository) FindStaleUploadedBasebackups(olderThan time.Time) ([]*Backup, error) { + var backups []*Backup + + err := storage. + GetDb(). + Where( + "status = ? AND upload_completed_at IS NOT NULL AND upload_completed_at < ?", + BackupStatusInProgress, + olderThan, + ). + Find(&backups).Error + if err != nil { + return nil, err + } + + return backups, nil +} + func (r *BackupRepository) FindLastWalSegmentAfter( databaseID uuid.UUID, afterSegmentName string, diff --git a/backend/internal/features/backups/backups/dto/dto.go b/backend/internal/features/backups/backups/dto/dto.go index 805c646..bd969c0 100644 --- a/backend/internal/features/backups/backups/dto/dto.go +++ b/backend/internal/features/backups/backups/dto/dto.go @@ -44,10 +44,10 @@ type ReportErrorRequest struct { Error string `json:"error" binding:"required"` } -type UploadGapResponse struct { - Error string `json:"error"` - ExpectedSegmentName string `json:"expectedSegmentName"` - ReceivedSegmentName string `json:"receivedSegmentName"` +type IsWalChainValidResponse struct { + IsValid bool `json:"isValid"` + Error string `json:"error,omitempty"` + LastContiguousSegment string `json:"lastContiguousSegment,omitempty"` } type RestorePlanFullBackup struct { @@ -77,3 +77,14 @@ type GetRestorePlanResponse struct { TotalSizeBytes int64 `json:"totalSizeBytes"` LatestAvailableSegment string `json:"latestAvailableSegment"` } + +type UploadBasebackupResponse struct { + BackupID uuid.UUID `json:"backupId"` +} + +type FinalizeBasebackupRequest struct { + BackupID uuid.UUID `json:"backupId" binding:"required"` + StartSegment string `json:"startSegment" binding:"required"` + StopSegment string `json:"stopSegment" binding:"required"` + Error *string `json:"error"` +} diff --git a/backend/internal/features/backups/backups/services/postgres_wal_service.go b/backend/internal/features/backups/backups/services/postgres_wal_service.go index 452733e..c33b04d 100644 --- a/backend/internal/features/backups/backups/services/postgres_wal_service.go +++ b/backend/internal/features/backups/backups/services/postgres_wal_service.go @@ -30,75 +30,46 @@ type PostgreWalBackupService struct { backupService *BackupService } -// UploadWal accepts a streaming WAL segment or basebackup upload from the agent. -// For WAL segments it validates the WAL chain before accepting. Returns an UploadGapResponse -// (409) when the chain is broken so the agent knows to trigger a full basebackup. -func (s *PostgreWalBackupService) UploadWal( +// UploadWalSegment accepts a streaming WAL segment upload from the agent. +// WAL segments are accepted unconditionally. +func (s *PostgreWalBackupService) UploadWalSegment( ctx context.Context, database *databases.Database, - uploadType backups_core.PgWalUploadType, walSegmentName string, - fullBackupWalStartSegment string, - fullBackupWalStopSegment string, - walSegmentSizeBytes int64, body io.Reader, -) (*backups_dto.UploadGapResponse, error) { +) error { if err := s.validateWalBackupType(database); err != nil { - return nil, err - } - - if uploadType == backups_core.PgWalUploadTypeBasebackup { - if fullBackupWalStartSegment == "" || fullBackupWalStopSegment == "" { - return nil, fmt.Errorf( - "fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads", - ) - } + return err } backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID) if err != nil { - return nil, fmt.Errorf("failed to get backup config: %w", err) + return fmt.Errorf("failed to get backup config: %w", err) } if backupConfig.Storage == nil { - return nil, fmt.Errorf("no storage configured for database %s", database.ID) + return fmt.Errorf("no storage configured for database %s", database.ID) } - if uploadType == backups_core.PgWalUploadTypeWal { - // Idempotency: check before chain validation so a successful re-upload is - // not misidentified as a gap. - existing, err := s.backupRepository.FindWalSegmentByName(database.ID, walSegmentName) - if err != nil { - return nil, fmt.Errorf("failed to check for duplicate WAL segment: %w", err) - } - - if existing != nil { - return nil, nil - } - - gapResp, err := s.validateWalChain(database.ID, walSegmentName, walSegmentSizeBytes) - if err != nil { - return nil, err - } - - if gapResp != nil { - return gapResp, nil - } + existing, err := s.backupRepository.FindWalSegmentByName(database.ID, walSegmentName) + if err != nil { + return fmt.Errorf("failed to check for duplicate WAL segment: %w", err) } - backup := s.createBackupRecord( + if existing != nil { + return nil + } + + backup := s.createWalSegmentRecord( database.ID, backupConfig.Storage.ID, - uploadType, database.Name, walSegmentName, - fullBackupWalStartSegment, - fullBackupWalStopSegment, backupConfig.Encryption, ) if err := s.backupRepository.Save(backup); err != nil { - return nil, fmt.Errorf("failed to create backup record: %w", err) + return fmt.Errorf("failed to create backup record: %w", err) } sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, body) @@ -106,12 +77,106 @@ func (s *PostgreWalBackupService) UploadWal( errMsg := streamErr.Error() s.markFailed(backup, errMsg) - return nil, fmt.Errorf("upload failed: %w", streamErr) + return fmt.Errorf("upload failed: %w", streamErr) } s.markCompleted(backup, sizeBytes) - return nil, nil + return nil +} + +// UploadBasebackup accepts a streaming basebackup upload from the agent (Phase 1). +// The backup stays IN_PROGRESS with UploadCompletedAt set after streaming finishes. +// The agent must call FinalizeBasebackup (Phase 2) with WAL segment names to complete. +func (s *PostgreWalBackupService) UploadBasebackup( + ctx context.Context, + database *databases.Database, + body io.Reader, +) (uuid.UUID, error) { + if err := s.validateWalBackupType(database); err != nil { + return uuid.Nil, err + } + + backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID) + if err != nil { + return uuid.Nil, fmt.Errorf("failed to get backup config: %w", err) + } + + if backupConfig.Storage == nil { + return uuid.Nil, fmt.Errorf("no storage configured for database %s", database.ID) + } + + backup := s.createBasebackupRecord( + database.ID, + backupConfig.Storage.ID, + database.Name, + backupConfig.Encryption, + ) + + if err := s.backupRepository.Save(backup); err != nil { + return uuid.Nil, fmt.Errorf("failed to create backup record: %w", err) + } + + sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, body) + if streamErr != nil { + errMsg := streamErr.Error() + s.markFailed(backup, errMsg) + + return uuid.Nil, fmt.Errorf("upload failed: %w", streamErr) + } + + now := time.Now().UTC() + backup.UploadCompletedAt = &now + backup.BackupSizeMb = float64(sizeBytes) / (1024 * 1024) + + if err := s.backupRepository.Save(backup); err != nil { + return uuid.Nil, fmt.Errorf("failed to update backup after upload: %w", err) + } + + return backup.ID, nil +} + +// FinalizeBasebackup completes a previously uploaded basebackup (Phase 2). +// Sets WAL segment names and marks the backup as COMPLETED, or marks it FAILED if errorMsg is provided. +func (s *PostgreWalBackupService) FinalizeBasebackup( + database *databases.Database, + backupID uuid.UUID, + startSegment string, + stopSegment string, + errorMsg *string, +) error { + if err := s.validateWalBackupType(database); err != nil { + return err + } + + backup, err := s.backupRepository.FindByID(backupID) + if err != nil { + return fmt.Errorf("backup not found: %w", err) + } + + if backup.DatabaseID != database.ID { + return fmt.Errorf("backup does not belong to this database") + } + + if backup.Status != backups_core.BackupStatusInProgress || backup.UploadCompletedAt == nil { + return fmt.Errorf("backup is not awaiting finalization") + } + + if errorMsg != nil { + s.markFailed(backup, *errorMsg) + + return nil + } + + backup.PgFullBackupWalStartSegmentName = &startSegment + backup.PgFullBackupWalStopSegmentName = &stopSegment + backup.Status = backups_core.BackupStatusCompleted + + if err := s.backupRepository.Save(backup); err != nil { + return fmt.Errorf("failed to finalize backup: %w", err) + } + + return nil } func (s *PostgreWalBackupService) GetRestorePlan( @@ -299,97 +364,97 @@ func (s *PostgreWalBackupService) ReportError( return nil } -func (s *PostgreWalBackupService) validateWalChain( - databaseID uuid.UUID, - incomingSegment string, - walSegmentSizeBytes int64, -) (*backups_dto.UploadGapResponse, error) { - fullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(databaseID) +// IsWalChainValid checks whether the WAL chain is continuous since the last completed full backup. +func (s *PostgreWalBackupService) IsWalChainValid( + database *databases.Database, +) (*backups_dto.IsWalChainValidResponse, error) { + if err := s.validateWalBackupType(database); err != nil { + return nil, err + } + + fullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(database.ID) if err != nil { return nil, fmt.Errorf("failed to query full backup: %w", err) } - // No full backup exists yet: cannot accept WAL segments without a chain anchor. if fullBackup == nil || fullBackup.PgFullBackupWalStopSegmentName == nil { - return &backups_dto.UploadGapResponse{ - Error: "no_full_backup", - ExpectedSegmentName: "", - ReceivedSegmentName: incomingSegment, + return &backups_dto.IsWalChainValidResponse{ + IsValid: false, + Error: "no_full_backup", }, nil } - stopSegment := *fullBackup.PgFullBackupWalStopSegmentName + startSegment := "" + if fullBackup.PgFullBackupWalStartSegmentName != nil { + startSegment = *fullBackup.PgFullBackupWalStartSegmentName + } - lastWal, err := s.backupRepository.FindLastWalSegmentAfter(databaseID, stopSegment) + walSegments, err := s.backupRepository.FindCompletedWalSegmentsAfter(database.ID, startSegment) if err != nil { - return nil, fmt.Errorf("failed to query last WAL segment: %w", err) + return nil, fmt.Errorf("failed to query WAL segments: %w", err) } - walCalculator := util_wal.NewWalCalculator(walSegmentSizeBytes) - - var chainTail string - if lastWal != nil && lastWal.PgWalSegmentName != nil { - chainTail = *lastWal.PgWalSegmentName - } else { - chainTail = stopSegment - } - - expectedNext, err := walCalculator.NextSegment(chainTail) - if err != nil { - return nil, fmt.Errorf("WAL arithmetic failed for %q: %w", chainTail, err) - } - - if incomingSegment != expectedNext { - return &backups_dto.UploadGapResponse{ - Error: "gap_detected", - ExpectedSegmentName: expectedNext, - ReceivedSegmentName: incomingSegment, + chainErr := s.validateRestoreWalChain(fullBackup, walSegments) + if chainErr != nil { + return &backups_dto.IsWalChainValidResponse{ + IsValid: false, + Error: chainErr.Error, + LastContiguousSegment: chainErr.LastContiguousSegment, }, nil } - return nil, nil + return &backups_dto.IsWalChainValidResponse{ + IsValid: true, + }, nil } -func (s *PostgreWalBackupService) createBackupRecord( +func (s *PostgreWalBackupService) createBasebackupRecord( databaseID uuid.UUID, storageID uuid.UUID, - uploadType backups_core.PgWalUploadType, dbName string, - walSegmentName string, - fullBackupWalStartSegment string, - fullBackupWalStopSegment string, encryption backups_config.BackupEncryption, ) *backups_core.Backup { now := time.Now().UTC() + walBackupType := backups_core.PgWalBackupTypeFullBackup backup := &backups_core.Backup{ - ID: uuid.New(), - DatabaseID: databaseID, - StorageID: storageID, - Status: backups_core.BackupStatusInProgress, - Encryption: encryption, - CreatedAt: now, + ID: uuid.New(), + DatabaseID: databaseID, + StorageID: storageID, + Status: backups_core.BackupStatusInProgress, + PgWalBackupType: &walBackupType, + Encryption: encryption, + CreatedAt: now, } backup.GenerateFilename(dbName) - if uploadType == backups_core.PgWalUploadTypeBasebackup { - walBackupType := backups_core.PgWalBackupTypeFullBackup - backup.PgWalBackupType = &walBackupType + return backup +} - if fullBackupWalStartSegment != "" { - backup.PgFullBackupWalStartSegmentName = &fullBackupWalStartSegment - } +func (s *PostgreWalBackupService) createWalSegmentRecord( + databaseID uuid.UUID, + storageID uuid.UUID, + dbName string, + walSegmentName string, + encryption backups_config.BackupEncryption, +) *backups_core.Backup { + now := time.Now().UTC() + walBackupType := backups_core.PgWalBackupTypeWalSegment - if fullBackupWalStopSegment != "" { - backup.PgFullBackupWalStopSegmentName = &fullBackupWalStopSegment - } - } else { - walBackupType := backups_core.PgWalBackupTypeWalSegment - backup.PgWalBackupType = &walBackupType - backup.PgWalSegmentName = &walSegmentName + backup := &backups_core.Backup{ + ID: uuid.New(), + DatabaseID: databaseID, + StorageID: storageID, + Status: backups_core.BackupStatusInProgress, + PgWalBackupType: &walBackupType, + PgWalSegmentName: &walSegmentName, + Encryption: encryption, + CreatedAt: now, } + backup.GenerateFilename(dbName) + return backup } diff --git a/backend/migrations/20260316151706_add_upload_completed_at.sql b/backend/migrations/20260316151706_add_upload_completed_at.sql new file mode 100644 index 0000000..d14b332 --- /dev/null +++ b/backend/migrations/20260316151706_add_upload_completed_at.sql @@ -0,0 +1,7 @@ +-- +goose Up +ALTER TABLE backups + ADD COLUMN upload_completed_at TIMESTAMPTZ; + +-- +goose Down +ALTER TABLE backups + DROP COLUMN upload_completed_at;