Merge pull request #446 from databasus/develop

Develop
This commit is contained in:
Rostislav Dugin
2026-03-17 14:39:24 +03:00
committed by GitHub
49 changed files with 4458 additions and 462 deletions

View File

@@ -73,6 +73,10 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
- Patch the answer accordingly
- Verify edge cases are handled
7. **Fix the reason, not the symptom:**
- If you find a bug or issue, ask "Why did this happen?" and fix the root cause
- Avoid quick fixes that don't address underlying problems
### Application guidelines:
**Scale your response to the task:**
@@ -88,6 +92,36 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
## Backend guidelines
### Naming
Variables and functions naming are the most important part of code readability. Always choose descriptive and meaningful names that clearly indicate the purpose and intent of the code.
Avoid abbreviations, unless they are widely accepted and unambiguous (e.g., `ID`, `URL`, `HTTP`). Use consistent naming conventions across the codebase.
Do not use one-two letters. For example:
Bad:
```
u := users.getUser()
pr, pw := io.Pipe()
r := bufio.NewReader(pr)
```
Good:
```
user := users.GetUser()
pipeReader, pipeWriter := io.Pipe()
bufferedReader := bufio.NewReader(pipeReader)
```
Exclusion: widely used variables like "db", "ctx", "req", "res", etc.
### Code style
**Always place private methods to the bottom of file**

4
agent/.gitignore vendored
View File

@@ -21,4 +21,6 @@ cmd.exe
temp/
valkey-data/
victoria-logs-data/
databasus.json
databasus.json
.test-tmp/
databasus.log

View File

@@ -1,14 +1,17 @@
package main
import (
"errors"
"flag"
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
"syscall"
"databasus-agent/internal/config"
"databasus-agent/internal/features/api"
"databasus-agent/internal/features/start"
"databasus-agent/internal/features/upgrade"
"databasus-agent/internal/logger"
@@ -25,6 +28,8 @@ func main() {
switch os.Args[1] {
case "start":
runStart(os.Args[2:])
case "_run":
runDaemon(os.Args[2:])
case "stop":
runStop()
case "status":
@@ -43,7 +48,6 @@ func main() {
func runStart(args []string) {
fs := flag.NewFlagSet("start", flag.ExitOnError)
isDebug := fs.Bool("debug", false, "Enable debug logging")
isSkipUpdate := fs.Bool("skip-update", false, "Skip auto-update check")
cfg := &config.Config{}
@@ -53,26 +57,59 @@ func runStart(args []string) {
fmt.Fprintf(os.Stderr, "Failed to save config: %v\n", err)
}
logger.Init(*isDebug)
log := logger.GetLogger()
isDev := checkIsDevelopment()
runUpdateCheck(cfg.DatabasusHost, *isSkipUpdate, isDev, log)
if err := start.Run(cfg, log); err != nil {
if err := start.Start(cfg, Version, isDev, log); err != nil {
if errors.Is(err, upgrade.ErrUpgradeRestart) {
reexecAfterUpgrade(log)
}
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
}
func runDaemon(args []string) {
fs := flag.NewFlagSet("_run", flag.ExitOnError)
if err := fs.Parse(args); err != nil {
os.Exit(1)
}
log := logger.GetLogger()
cfg := &config.Config{}
cfg.LoadFromJSON()
if err := start.RunDaemon(cfg, Version, checkIsDevelopment(), log); err != nil {
if errors.Is(err, upgrade.ErrUpgradeRestart) {
reexecAfterUpgrade(log)
}
log.Error("Agent exited with error", "error", err)
os.Exit(1)
}
}
func runStop() {
logger.Init(false)
logger.GetLogger().Info("stop: stub — not yet implemented")
log := logger.GetLogger()
if err := start.Stop(log); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
}
func runStatus() {
logger.Init(false)
logger.GetLogger().Info("status: stub — not yet implemented")
log := logger.GetLogger()
if err := start.Status(log); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
}
func runRestore(args []string) {
@@ -82,7 +119,6 @@ func runRestore(args []string) {
backupID := fs.String("backup-id", "", "Full backup UUID (optional)")
targetTime := fs.String("target-time", "", "PITR target time in RFC3339 (optional)")
isYes := fs.Bool("yes", false, "Skip confirmation prompt")
isDebug := fs.Bool("debug", false, "Enable debug logging")
isSkipUpdate := fs.Bool("skip-update", false, "Skip auto-update check")
cfg := &config.Config{}
@@ -92,7 +128,6 @@ func runRestore(args []string) {
fmt.Fprintf(os.Stderr, "Failed to save config: %v\n", err)
}
logger.Init(*isDebug)
log := logger.GetLogger()
isDev := checkIsDevelopment()
@@ -126,10 +161,17 @@ func runUpdateCheck(host string, isSkipUpdate, isDev bool, log *slog.Logger) {
return
}
if err := upgrade.CheckAndUpdate(host, Version, isDev, log); err != nil {
apiClient := api.NewClient(host, "", log)
isUpgraded, err := upgrade.CheckAndUpdate(apiClient, Version, isDev, log)
if err != nil {
log.Error("Auto-update failed", "error", err)
os.Exit(1)
}
if isUpgraded {
reexecAfterUpgrade(log)
}
}
func checkIsDevelopment() bool {
@@ -168,3 +210,18 @@ func parseEnvMode(data []byte) bool {
return false
}
func reexecAfterUpgrade(log *slog.Logger) {
selfPath, err := os.Executable()
if err != nil {
log.Error("Failed to resolve executable for re-exec", "error", err)
os.Exit(1)
}
log.Info("Re-executing after upgrade...")
if err := syscall.Exec(selfPath, os.Args, os.Environ()); err != nil {
log.Error("Failed to re-exec after upgrade", "error", err)
os.Exit(1)
}
}

View File

@@ -24,6 +24,7 @@ func main() {
http.HandleFunc("/api/v1/system/version", s.handleVersion)
http.HandleFunc("/api/v1/system/agent", s.handleAgentDownload)
http.HandleFunc("/mock/set-version", s.handleSetVersion)
http.HandleFunc("/mock/set-binary-path", s.handleSetBinaryPath)
http.HandleFunc("/health", s.handleHealth)
addr := ":" + port
@@ -78,6 +79,29 @@ func (s *server) handleSetVersion(w http.ResponseWriter, r *http.Request) {
_, _ = fmt.Fprintf(w, "version set to %s", body.Version)
}
func (s *server) handleSetBinaryPath(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "POST only", http.StatusMethodNotAllowed)
return
}
var body struct {
BinaryPath string `json:"binaryPath"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
s.mu.Lock()
s.binaryPath = body.BinaryPath
s.mu.Unlock()
log.Printf("POST /mock/set-binary-path -> %s", body.BinaryPath)
_, _ = fmt.Fprintf(w, "binary path set to %s", body.BinaryPath)
}
func (s *server) handleHealth(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))

View File

@@ -27,11 +27,12 @@ run_test() {
if [ "$MODE" = "host" ]; then
run_test "Test 1: Upgrade success (v1 -> v2)" "$SCRIPT_DIR/test-upgrade-success.sh"
run_test "Test 2: Upgrade skip (version matches)" "$SCRIPT_DIR/test-upgrade-skip.sh"
run_test "Test 3: pg_basebackup in PATH" "$SCRIPT_DIR/test-pg-host-path.sh"
run_test "Test 4: pg_basebackup via bindir" "$SCRIPT_DIR/test-pg-host-bindir.sh"
run_test "Test 3: Background upgrade (v1 -> v2 while running)" "$SCRIPT_DIR/test-upgrade-background.sh"
run_test "Test 4: pg_basebackup in PATH" "$SCRIPT_DIR/test-pg-host-path.sh"
run_test "Test 5: pg_basebackup via bindir" "$SCRIPT_DIR/test-pg-host-bindir.sh"
elif [ "$MODE" = "docker" ]; then
run_test "Test 5: pg_basebackup via docker exec" "$SCRIPT_DIR/test-pg-docker-exec.sh"
run_test "Test 6: pg_basebackup via docker exec" "$SCRIPT_DIR/test-pg-docker-exec.sh"
else
echo "Unknown mode: $MODE (expected 'host' or 'docker')"

View File

@@ -5,6 +5,16 @@ ARTIFACTS="/opt/agent/artifacts"
AGENT="/tmp/test-agent"
PG_CONTAINER="e2e-agent-postgres"
# Cleanup from previous runs
pkill -f "test-agent" 2>/dev/null || true
for i in $(seq 1 20); do
pgrep -f "test-agent" > /dev/null 2>&1 || break
sleep 0.5
done
pkill -9 -f "test-agent" 2>/dev/null || true
sleep 0.5
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
# Copy agent binary
cp "$ARTIFACTS/agent-v1" "$AGENT"
chmod +x "$AGENT"
@@ -26,7 +36,7 @@ OUTPUT=$("$AGENT" start \
--pg-port 5432 \
--pg-user testuser \
--pg-password testpassword \
--wal-dir /tmp/wal \
--pg-wal-dir /tmp/wal \
--pg-type docker \
--pg-docker-container-name "$PG_CONTAINER" 2>&1)

View File

@@ -5,6 +5,16 @@ ARTIFACTS="/opt/agent/artifacts"
AGENT="/tmp/test-agent"
CUSTOM_BIN_DIR="/opt/pg/bin"
# Cleanup from previous runs
pkill -f "test-agent" 2>/dev/null || true
for i in $(seq 1 20); do
pgrep -f "test-agent" > /dev/null 2>&1 || break
sleep 0.5
done
pkill -9 -f "test-agent" 2>/dev/null || true
sleep 0.5
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
# Copy agent binary
cp "$ARTIFACTS/agent-v1" "$AGENT"
chmod +x "$AGENT"
@@ -32,7 +42,7 @@ OUTPUT=$("$AGENT" start \
--pg-port 5432 \
--pg-user testuser \
--pg-password testpassword \
--wal-dir /tmp/wal \
--pg-wal-dir /tmp/wal \
--pg-type host \
--pg-host-bin-dir "$CUSTOM_BIN_DIR" 2>&1)

View File

@@ -4,6 +4,16 @@ set -euo pipefail
ARTIFACTS="/opt/agent/artifacts"
AGENT="/tmp/test-agent"
# Cleanup from previous runs
pkill -f "test-agent" 2>/dev/null || true
for i in $(seq 1 20); do
pgrep -f "test-agent" > /dev/null 2>&1 || break
sleep 0.5
done
pkill -9 -f "test-agent" 2>/dev/null || true
sleep 0.5
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
# Copy agent binary
cp "$ARTIFACTS/agent-v1" "$AGENT"
chmod +x "$AGENT"
@@ -25,7 +35,7 @@ OUTPUT=$("$AGENT" start \
--pg-port 5432 \
--pg-user testuser \
--pg-password testpassword \
--wal-dir /tmp/wal \
--pg-wal-dir /tmp/wal \
--pg-type host 2>&1)
EXIT_CODE=$?

View File

@@ -0,0 +1,90 @@
#!/bin/bash
set -euo pipefail
ARTIFACTS="/opt/agent/artifacts"
AGENT="/tmp/test-agent"
# Cleanup from previous runs
pkill -f "test-agent" 2>/dev/null || true
for i in $(seq 1 20); do
pgrep -f "test-agent" > /dev/null 2>&1 || break
sleep 0.5
done
pkill -9 -f "test-agent" 2>/dev/null || true
sleep 0.5
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
# Set mock server to v1.0.0 (same as agent — no sync upgrade on start)
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
-H "Content-Type: application/json" \
-d '{"version":"v1.0.0"}'
curl -sf -X POST http://e2e-mock-server:4050/mock/set-binary-path \
-H "Content-Type: application/json" \
-d '{"binaryPath":"/artifacts/agent-v1"}'
# Copy v1 binary to writable location
cp "$ARTIFACTS/agent-v1" "$AGENT"
chmod +x "$AGENT"
# Verify initial version
VERSION=$("$AGENT" version)
if [ "$VERSION" != "v1.0.0" ]; then
echo "FAIL: Expected initial version v1.0.0, got $VERSION"
exit 1
fi
echo "Initial version: $VERSION"
# Start agent as daemon (versions match → no sync upgrade)
mkdir -p /tmp/wal
"$AGENT" start \
--databasus-host http://e2e-mock-server:4050 \
--db-id test-db-id \
--token test-token \
--pg-host e2e-postgres \
--pg-port 5432 \
--pg-user testuser \
--pg-password testpassword \
--pg-wal-dir /tmp/wal \
--pg-type host
echo "Agent started as daemon, waiting for stabilization..."
sleep 2
# Change mock server to v2.0.0 and point to v2 binary
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
-H "Content-Type: application/json" \
-d '{"version":"v2.0.0"}'
curl -sf -X POST http://e2e-mock-server:4050/mock/set-binary-path \
-H "Content-Type: application/json" \
-d '{"binaryPath":"/artifacts/agent-v2"}'
echo "Mock server updated to v2.0.0, waiting for background upgrade..."
# Poll for upgrade (timeout 60s, poll every 3s)
DEADLINE=$((SECONDS + 60))
while [ $SECONDS -lt $DEADLINE ]; do
VERSION=$("$AGENT" version)
if [ "$VERSION" = "v2.0.0" ]; then
echo "Binary upgraded to $VERSION"
break
fi
sleep 3
done
VERSION=$("$AGENT" version)
if [ "$VERSION" != "v2.0.0" ]; then
echo "FAIL: Expected v2.0.0 after background upgrade, got $VERSION"
cat databasus.log 2>/dev/null || true
exit 1
fi
# Verify agent is still running after restart
sleep 2
"$AGENT" status || true
# Cleanup
"$AGENT" stop || true
echo "Background upgrade test passed"

View File

@@ -4,6 +4,16 @@ set -euo pipefail
ARTIFACTS="/opt/agent/artifacts"
AGENT="/tmp/test-agent"
# Cleanup from previous runs
pkill -f "test-agent" 2>/dev/null || true
for i in $(seq 1 20); do
pgrep -f "test-agent" > /dev/null 2>&1 || break
sleep 0.5
done
pkill -9 -f "test-agent" 2>/dev/null || true
sleep 0.5
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
# Set mock server to return v1.0.0 (same as agent)
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
-H "Content-Type: application/json" \
@@ -30,7 +40,7 @@ OUTPUT=$("$AGENT" start \
--pg-port 5432 \
--pg-user testuser \
--pg-password testpassword \
--wal-dir /tmp/wal \
--pg-wal-dir /tmp/wal \
--pg-type host 2>&1) || true
echo "$OUTPUT"

View File

@@ -4,11 +4,25 @@ set -euo pipefail
ARTIFACTS="/opt/agent/artifacts"
AGENT="/tmp/test-agent"
# Ensure mock server returns v2.0.0
# Cleanup from previous runs
pkill -f "test-agent" 2>/dev/null || true
for i in $(seq 1 20); do
pgrep -f "test-agent" > /dev/null 2>&1 || break
sleep 0.5
done
pkill -9 -f "test-agent" 2>/dev/null || true
sleep 0.5
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
# Ensure mock server returns v2.0.0 and serves v2 binary
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
-H "Content-Type: application/json" \
-d '{"version":"v2.0.0"}'
curl -sf -X POST http://e2e-mock-server:4050/mock/set-binary-path \
-H "Content-Type: application/json" \
-d '{"binaryPath":"/artifacts/agent-v2"}'
# Copy v1 binary to writable location
cp "$ARTIFACTS/agent-v1" "$AGENT"
chmod +x "$AGENT"
@@ -37,7 +51,7 @@ OUTPUT=$("$AGENT" start \
--pg-port 5432 \
--pg-user testuser \
--pg-password testpassword \
--wal-dir /tmp/wal \
--pg-wal-dir /tmp/wal \
--pg-type host 2>&1) || true
echo "$OUTPUT"

View File

@@ -3,7 +3,9 @@ module databasus-agent
go 1.26.1
require (
github.com/go-resty/resty/v2 v2.17.2
github.com/jackc/pgx/v5 v5.8.0
github.com/klauspost/compress v1.18.4
github.com/stretchr/testify v1.11.1
)
@@ -14,6 +16,7 @@ require (
github.com/kr/text v0.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
golang.org/x/net v0.43.0 // indirect
golang.org/x/text v0.29.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -2,6 +2,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk=
github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -10,6 +12,8 @@ github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
@@ -23,10 +27,14 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

View File

@@ -24,7 +24,7 @@ type Config struct {
PgType string `json:"pgType"`
PgHostBinDir string `json:"pgHostBinDir"`
PgDockerContainerName string `json:"pgDockerContainerName"`
WalDir string `json:"walDir"`
PgWalDir string `json:"pgWalDir"`
IsDeleteWalAfterUpload *bool `json:"deleteWalAfterUpload"`
flags parsedFlags
@@ -51,7 +51,7 @@ func (c *Config) LoadFromJSONAndArgs(fs *flag.FlagSet, args []string) {
c.flags.pgType = fs.String("pg-type", "", "PostgreSQL type: host or docker")
c.flags.pgHostBinDir = fs.String("pg-host-bin-dir", "", "Path to PG bin directory (host mode)")
c.flags.pgDockerContainerName = fs.String("pg-docker-container-name", "", "Docker container name (docker mode)")
c.flags.walDir = fs.String("wal-dir", "", "Path to WAL queue directory")
c.flags.pgWalDir = fs.String("pg-wal-dir", "", "Path to WAL queue directory")
if err := fs.Parse(args); err != nil {
os.Exit(1)
@@ -73,6 +73,11 @@ func (c *Config) SaveToJSON() error {
return os.WriteFile(configFileName, data, 0o644)
}
func (c *Config) LoadFromJSON() {
c.loadFromJSON()
c.applyDefaults()
}
func (c *Config) loadFromJSON() {
data, err := os.ReadFile(configFileName)
if err != nil {
@@ -122,7 +127,7 @@ func (c *Config) initSources() {
"pg-type": "not configured",
"pg-host-bin-dir": "not configured",
"pg-docker-container-name": "not configured",
"wal-dir": "not configured",
"pg-wal-dir": "not configured",
"delete-wal-after-upload": "not configured",
}
@@ -164,8 +169,8 @@ func (c *Config) initSources() {
c.flags.sources["pg-docker-container-name"] = configFileName
}
if c.WalDir != "" {
c.flags.sources["wal-dir"] = configFileName
if c.PgWalDir != "" {
c.flags.sources["pg-wal-dir"] = configFileName
}
// IsDeleteWalAfterUpload always has a value after applyDefaults
@@ -223,9 +228,9 @@ func (c *Config) applyFlags() {
c.flags.sources["pg-docker-container-name"] = "command line args"
}
if c.flags.walDir != nil && *c.flags.walDir != "" {
c.WalDir = *c.flags.walDir
c.flags.sources["wal-dir"] = "command line args"
if c.flags.pgWalDir != nil && *c.flags.pgWalDir != "" {
c.PgWalDir = *c.flags.pgWalDir
c.flags.sources["pg-wal-dir"] = "command line args"
}
}
@@ -246,7 +251,7 @@ func (c *Config) logConfigSources() {
"source",
c.flags.sources["pg-docker-container-name"],
)
log.Info("wal-dir", "value", c.WalDir, "source", c.flags.sources["wal-dir"])
log.Info("pg-wal-dir", "value", c.PgWalDir, "source", c.flags.sources["pg-wal-dir"])
log.Info(
"delete-wal-after-upload",
"value",

View File

@@ -142,7 +142,7 @@ func Test_LoadFromJSONAndArgs_PgFieldsLoadedFromJSON(t *testing.T) {
PgType: "docker",
PgHostBinDir: "/usr/bin",
PgDockerContainerName: "pg-container",
WalDir: "/opt/wal",
PgWalDir: "/opt/wal",
IsDeleteWalAfterUpload: &deleteWal,
})
@@ -157,7 +157,7 @@ func Test_LoadFromJSONAndArgs_PgFieldsLoadedFromJSON(t *testing.T) {
assert.Equal(t, "docker", cfg.PgType)
assert.Equal(t, "/usr/bin", cfg.PgHostBinDir)
assert.Equal(t, "pg-container", cfg.PgDockerContainerName)
assert.Equal(t, "/opt/wal", cfg.WalDir)
assert.Equal(t, "/opt/wal", cfg.PgWalDir)
assert.Equal(t, false, *cfg.IsDeleteWalAfterUpload)
}
@@ -174,7 +174,7 @@ func Test_LoadFromJSONAndArgs_PgFieldsLoadedFromArgs(t *testing.T) {
"--pg-type", "docker",
"--pg-host-bin-dir", "/custom/bin",
"--pg-docker-container-name", "my-pg",
"--wal-dir", "/var/wal",
"--pg-wal-dir", "/var/wal",
})
assert.Equal(t, "arg-pg-host", cfg.PgHost)
@@ -184,17 +184,17 @@ func Test_LoadFromJSONAndArgs_PgFieldsLoadedFromArgs(t *testing.T) {
assert.Equal(t, "docker", cfg.PgType)
assert.Equal(t, "/custom/bin", cfg.PgHostBinDir)
assert.Equal(t, "my-pg", cfg.PgDockerContainerName)
assert.Equal(t, "/var/wal", cfg.WalDir)
assert.Equal(t, "/var/wal", cfg.PgWalDir)
}
func Test_LoadFromJSONAndArgs_PgArgsOverrideJSON(t *testing.T) {
dir := setupTempDir(t)
writeConfigJSON(t, dir, Config{
PgHost: "json-host",
PgPort: 5432,
PgUser: "json-user",
PgType: "host",
WalDir: "/json/wal",
PgHost: "json-host",
PgPort: 5432,
PgUser: "json-user",
PgType: "host",
PgWalDir: "/json/wal",
})
cfg := &Config{}
@@ -205,7 +205,7 @@ func Test_LoadFromJSONAndArgs_PgArgsOverrideJSON(t *testing.T) {
"--pg-user", "arg-user",
"--pg-type", "docker",
"--pg-docker-container-name", "my-container",
"--wal-dir", "/arg/wal",
"--pg-wal-dir", "/arg/wal",
})
assert.Equal(t, "arg-host", cfg.PgHost)
@@ -213,7 +213,7 @@ func Test_LoadFromJSONAndArgs_PgArgsOverrideJSON(t *testing.T) {
assert.Equal(t, "arg-user", cfg.PgUser)
assert.Equal(t, "docker", cfg.PgType)
assert.Equal(t, "my-container", cfg.PgDockerContainerName)
assert.Equal(t, "/arg/wal", cfg.WalDir)
assert.Equal(t, "/arg/wal", cfg.PgWalDir)
}
func Test_LoadFromJSONAndArgs_DefaultsApplied_WhenNoJSONAndNoArgs(t *testing.T) {
@@ -244,7 +244,7 @@ func Test_SaveToJSON_PgFieldsSavedCorrectly(t *testing.T) {
PgType: "docker",
PgHostBinDir: "/usr/bin",
PgDockerContainerName: "pg-container",
WalDir: "/opt/wal",
PgWalDir: "/opt/wal",
IsDeleteWalAfterUpload: &deleteWal,
}
@@ -260,7 +260,7 @@ func Test_SaveToJSON_PgFieldsSavedCorrectly(t *testing.T) {
assert.Equal(t, "docker", saved.PgType)
assert.Equal(t, "/usr/bin", saved.PgHostBinDir)
assert.Equal(t, "pg-container", saved.PgDockerContainerName)
assert.Equal(t, "/opt/wal", saved.WalDir)
assert.Equal(t, "/opt/wal", saved.PgWalDir)
require.NotNil(t, saved.IsDeleteWalAfterUpload)
assert.Equal(t, false, *saved.IsDeleteWalAfterUpload)
}

View File

@@ -11,7 +11,7 @@ type parsedFlags struct {
pgType *string
pgHostBinDir *string
pgDockerContainerName *string
walDir *string
pgWalDir *string
sources map[string]string
}

View File

@@ -0,0 +1,288 @@
package api
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"time"
"github.com/go-resty/resty/v2"
)
const (
chainValidPath = "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup"
nextBackupTimePath = "/api/v1/backups/postgres/wal/next-full-backup-time"
walUploadPath = "/api/v1/backups/postgres/wal/upload/wal"
fullStartPath = "/api/v1/backups/postgres/wal/upload/full-start"
fullCompletePath = "/api/v1/backups/postgres/wal/upload/full-complete"
reportErrorPath = "/api/v1/backups/postgres/wal/error"
versionPath = "/api/v1/system/version"
agentBinaryPath = "/api/v1/system/agent"
apiCallTimeout = 30 * time.Second
maxRetryAttempts = 3
retryBaseDelay = 1 * time.Second
)
type Client struct {
json *resty.Client
stream *resty.Client
host string
log *slog.Logger
}
func NewClient(host, token string, log *slog.Logger) *Client {
setAuth := func(_ *resty.Client, req *resty.Request) error {
if token != "" {
req.SetHeader("Authorization", token)
}
return nil
}
jsonClient := resty.New().
SetTimeout(apiCallTimeout).
SetRetryCount(maxRetryAttempts - 1).
SetRetryWaitTime(retryBaseDelay).
SetRetryMaxWaitTime(4 * retryBaseDelay).
AddRetryCondition(func(resp *resty.Response, err error) bool {
return err != nil || resp.StatusCode() >= 500
}).
OnBeforeRequest(setAuth)
streamClient := resty.New().
OnBeforeRequest(setAuth)
return &Client{
json: jsonClient,
stream: streamClient,
host: host,
log: log,
}
}
func (c *Client) CheckWalChainValidity(ctx context.Context) (*WalChainValidityResponse, error) {
var resp WalChainValidityResponse
httpResp, err := c.json.R().
SetContext(ctx).
SetResult(&resp).
Get(c.buildURL(chainValidPath))
if err != nil {
return nil, err
}
if err := c.checkResponse(httpResp, "check WAL chain validity"); err != nil {
return nil, err
}
return &resp, nil
}
func (c *Client) GetNextFullBackupTime(ctx context.Context) (*NextFullBackupTimeResponse, error) {
var resp NextFullBackupTimeResponse
httpResp, err := c.json.R().
SetContext(ctx).
SetResult(&resp).
Get(c.buildURL(nextBackupTimePath))
if err != nil {
return nil, err
}
if err := c.checkResponse(httpResp, "get next full backup time"); err != nil {
return nil, err
}
return &resp, nil
}
func (c *Client) ReportBackupError(ctx context.Context, errMsg string) error {
httpResp, err := c.json.R().
SetContext(ctx).
SetBody(reportErrorRequest{Error: errMsg}).
Post(c.buildURL(reportErrorPath))
if err != nil {
return err
}
return c.checkResponse(httpResp, "report backup error")
}
func (c *Client) UploadBasebackup(
ctx context.Context,
body io.Reader,
) (*UploadBasebackupResponse, error) {
resp, err := c.stream.R().
SetContext(ctx).
SetBody(body).
SetHeader("Content-Type", "application/octet-stream").
SetDoNotParseResponse(true).
Post(c.buildURL(fullStartPath))
if err != nil {
return nil, fmt.Errorf("upload request: %w", err)
}
defer func() { _ = resp.RawBody().Close() }()
if resp.StatusCode() != http.StatusOK {
respBody, _ := io.ReadAll(resp.RawBody())
return nil, fmt.Errorf("upload failed with status %d: %s", resp.StatusCode(), string(respBody))
}
var result UploadBasebackupResponse
if err := json.NewDecoder(resp.RawBody()).Decode(&result); err != nil {
return nil, fmt.Errorf("decode upload response: %w", err)
}
return &result, nil
}
func (c *Client) FinalizeBasebackup(
ctx context.Context,
backupID string,
startSegment string,
stopSegment string,
) error {
resp, err := c.json.R().
SetContext(ctx).
SetBody(finalizeBasebackupRequest{
BackupID: backupID,
StartSegment: startSegment,
StopSegment: stopSegment,
}).
Post(c.buildURL(fullCompletePath))
if err != nil {
return fmt.Errorf("finalize request: %w", err)
}
if resp.StatusCode() != http.StatusOK {
return fmt.Errorf("finalize failed with status %d: %s", resp.StatusCode(), resp.String())
}
return nil
}
func (c *Client) FinalizeBasebackupWithError(
ctx context.Context,
backupID string,
errMsg string,
) error {
resp, err := c.json.R().
SetContext(ctx).
SetBody(finalizeBasebackupRequest{
BackupID: backupID,
Error: &errMsg,
}).
Post(c.buildURL(fullCompletePath))
if err != nil {
return fmt.Errorf("finalize-with-error request: %w", err)
}
if resp.StatusCode() != http.StatusOK {
return fmt.Errorf("finalize-with-error failed with status %d: %s", resp.StatusCode(), resp.String())
}
return nil
}
func (c *Client) UploadWalSegment(
ctx context.Context,
segmentName string,
body io.Reader,
) (*UploadWalSegmentResult, error) {
resp, err := c.stream.R().
SetContext(ctx).
SetBody(body).
SetHeader("Content-Type", "application/octet-stream").
SetHeader("X-Wal-Segment-Name", segmentName).
SetDoNotParseResponse(true).
Post(c.buildURL(walUploadPath))
if err != nil {
return nil, fmt.Errorf("upload request: %w", err)
}
defer func() { _ = resp.RawBody().Close() }()
switch resp.StatusCode() {
case http.StatusNoContent:
return &UploadWalSegmentResult{IsGapDetected: false}, nil
case http.StatusConflict:
var errResp uploadErrorResponse
if err := json.NewDecoder(resp.RawBody()).Decode(&errResp); err != nil {
return &UploadWalSegmentResult{IsGapDetected: true}, nil
}
return &UploadWalSegmentResult{
IsGapDetected: true,
ExpectedSegmentName: errResp.ExpectedSegmentName,
ReceivedSegmentName: errResp.ReceivedSegmentName,
}, nil
default:
respBody, _ := io.ReadAll(resp.RawBody())
return nil, fmt.Errorf("upload failed with status %d: %s", resp.StatusCode(), string(respBody))
}
}
func (c *Client) FetchServerVersion(ctx context.Context) (string, error) {
var ver versionResponse
httpResp, err := c.json.R().
SetContext(ctx).
SetResult(&ver).
Get(c.buildURL(versionPath))
if err != nil {
return "", err
}
if err := c.checkResponse(httpResp, "fetch server version"); err != nil {
return "", err
}
return ver.Version, nil
}
func (c *Client) DownloadAgentBinary(ctx context.Context, arch, destPath string) error {
resp, err := c.stream.R().
SetContext(ctx).
SetQueryParam("arch", arch).
SetDoNotParseResponse(true).
Get(c.buildURL(agentBinaryPath))
if err != nil {
return err
}
defer func() { _ = resp.RawBody().Close() }()
if resp.StatusCode() != http.StatusOK {
return fmt.Errorf("server returned %d for agent download", resp.StatusCode())
}
f, err := os.Create(destPath)
if err != nil {
return err
}
defer func() { _ = f.Close() }()
_, err = io.Copy(f, resp.RawBody())
return err
}
func (c *Client) buildURL(path string) string {
return c.host + path
}
func (c *Client) checkResponse(resp *resty.Response, method string) error {
if resp.StatusCode() >= 400 {
return fmt.Errorf("%s: server returned status %d: %s", method, resp.StatusCode(), resp.String())
}
return nil
}

View File

@@ -0,0 +1,44 @@
package api
import "time"
type WalChainValidityResponse struct {
IsValid bool `json:"isValid"`
Error string `json:"error,omitempty"`
LastContiguousSegment string `json:"lastContiguousSegment,omitempty"`
}
type NextFullBackupTimeResponse struct {
NextFullBackupTime *time.Time `json:"nextFullBackupTime"`
}
type UploadWalSegmentResult struct {
IsGapDetected bool
ExpectedSegmentName string
ReceivedSegmentName string
}
type reportErrorRequest struct {
Error string `json:"error"`
}
type versionResponse struct {
Version string `json:"version"`
}
type UploadBasebackupResponse struct {
BackupID string `json:"backupId"`
}
type finalizeBasebackupRequest struct {
BackupID string `json:"backupId"`
StartSegment string `json:"startSegment"`
StopSegment string `json:"stopSegment"`
Error *string `json:"error,omitempty"`
}
type uploadErrorResponse struct {
Error string `json:"error"`
ExpectedSegmentName string `json:"expectedSegmentName"`
ReceivedSegmentName string `json:"receivedSegmentName"`
}

View File

@@ -0,0 +1,292 @@
package full_backup
import (
"bytes"
"context"
"fmt"
"io"
"log/slog"
"os"
"os/exec"
"path/filepath"
"sync/atomic"
"time"
"github.com/klauspost/compress/zstd"
"databasus-agent/internal/config"
"databasus-agent/internal/features/api"
)
const (
checkInterval = 30 * time.Second
retryDelay = 1 * time.Minute
uploadTimeout = 30 * time.Minute
)
var retryDelayOverride *time.Duration
type CmdBuilder func(ctx context.Context) *exec.Cmd
// FullBackuper runs pg_basebackup when the WAL chain is broken or a scheduled backup is due.
//
// Every 30 seconds it checks two conditions via the Databasus API:
// 1. WAL chain validity — if broken or no full backup exists, triggers an immediate basebackup.
// 2. Scheduled backup time — if the next full backup time has passed, triggers a basebackup.
//
// Only one basebackup runs at a time (guarded by atomic bool).
// On failure the error is reported to the server and the backup retries after 1 minute, indefinitely.
// WAL segment uploads (handled by wal.Streamer) continue independently and are not paused.
//
// pg_basebackup runs as "pg_basebackup -Ft -D - -X none --verbose". Stdout (tar) is zstd-compressed
// and uploaded to the server. Stderr is parsed for WAL start/stop segment names (LSN → segment arithmetic).
type FullBackuper struct {
cfg *config.Config
apiClient *api.Client
log *slog.Logger
isRunning atomic.Bool
cmdBuilder CmdBuilder
}
func NewFullBackuper(cfg *config.Config, apiClient *api.Client, log *slog.Logger) *FullBackuper {
backuper := &FullBackuper{
cfg: cfg,
apiClient: apiClient,
log: log,
}
backuper.cmdBuilder = backuper.defaultCmdBuilder
return backuper
}
func (backuper *FullBackuper) Run(ctx context.Context) {
backuper.log.Info("Full backuper started")
backuper.checkAndRunIfNeeded(ctx)
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
backuper.log.Info("Full backuper stopping")
return
case <-ticker.C:
backuper.checkAndRunIfNeeded(ctx)
}
}
}
func (backuper *FullBackuper) checkAndRunIfNeeded(ctx context.Context) {
if backuper.isRunning.Load() {
backuper.log.Debug("Skipping check: basebackup already in progress")
return
}
chainResp, err := backuper.apiClient.CheckWalChainValidity(ctx)
if err != nil {
backuper.log.Error("Failed to check WAL chain validity", "error", err)
return
}
if !chainResp.IsValid {
backuper.log.Info("WAL chain is invalid, triggering basebackup",
"error", chainResp.Error,
"lastContiguousSegment", chainResp.LastContiguousSegment,
)
backuper.runBasebackupWithRetry(ctx)
return
}
nextTimeResp, err := backuper.apiClient.GetNextFullBackupTime(ctx)
if err != nil {
backuper.log.Error("Failed to check next full backup time", "error", err)
return
}
if nextTimeResp.NextFullBackupTime == nil || !nextTimeResp.NextFullBackupTime.After(time.Now().UTC()) {
backuper.log.Info("Scheduled full backup is due, triggering basebackup")
backuper.runBasebackupWithRetry(ctx)
return
}
backuper.log.Debug("No basebackup needed",
"nextFullBackupTime", nextTimeResp.NextFullBackupTime,
)
}
func (backuper *FullBackuper) runBasebackupWithRetry(ctx context.Context) {
if !backuper.isRunning.CompareAndSwap(false, true) {
backuper.log.Debug("Skipping basebackup: already running")
return
}
defer backuper.isRunning.Store(false)
for {
if ctx.Err() != nil {
return
}
backuper.log.Info("Starting pg_basebackup")
err := backuper.executeAndUploadBasebackup(ctx)
if err == nil {
backuper.log.Info("Basebackup completed successfully")
return
}
backuper.log.Error("Basebackup failed", "error", err)
backuper.reportError(ctx, err.Error())
delay := retryDelay
if retryDelayOverride != nil {
delay = *retryDelayOverride
}
backuper.log.Info("Retrying basebackup after delay", "delay", delay)
select {
case <-ctx.Done():
return
case <-time.After(delay):
}
}
}
func (backuper *FullBackuper) executeAndUploadBasebackup(ctx context.Context) error {
cmd := backuper.cmdBuilder(ctx)
var stderrBuf bytes.Buffer
cmd.Stderr = &stderrBuf
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("create stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("start pg_basebackup: %w", err)
}
// Phase 1: Stream compressed data via io.Pipe directly to the API.
pipeReader, pipeWriter := io.Pipe()
go backuper.compressAndStream(pipeWriter, stdoutPipe)
uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout)
defer cancel()
uploadResp, uploadErr := backuper.apiClient.UploadBasebackup(uploadCtx, pipeReader)
cmdErr := cmd.Wait()
if uploadErr != nil {
return fmt.Errorf("upload basebackup: %w", uploadErr)
}
if cmdErr != nil {
errMsg := fmt.Sprintf("pg_basebackup exited with error: %v (stderr: %s)", cmdErr, stderrBuf.String())
_ = backuper.apiClient.FinalizeBasebackupWithError(ctx, uploadResp.BackupID, errMsg)
return fmt.Errorf("pg_basebackup: %w", cmdErr)
}
// Phase 2: Parse stderr for WAL segments and finalize the backup.
stderrStr := stderrBuf.String()
backuper.log.Debug("pg_basebackup stderr", "stderr", stderrStr)
startSegment, stopSegment, err := ParseBasebackupStderr(stderrStr)
if err != nil {
errMsg := fmt.Sprintf("parse pg_basebackup stderr: %v", err)
_ = backuper.apiClient.FinalizeBasebackupWithError(ctx, uploadResp.BackupID, errMsg)
return fmt.Errorf("parse pg_basebackup stderr: %w", err)
}
backuper.log.Info("Basebackup WAL segments parsed",
"startSegment", startSegment,
"stopSegment", stopSegment,
"backupId", uploadResp.BackupID,
)
if err := backuper.apiClient.FinalizeBasebackup(ctx, uploadResp.BackupID, startSegment, stopSegment); err != nil {
return fmt.Errorf("finalize basebackup: %w", err)
}
return nil
}
func (backuper *FullBackuper) compressAndStream(pipeWriter *io.PipeWriter, reader io.Reader) {
encoder, err := zstd.NewWriter(pipeWriter,
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(5)),
zstd.WithEncoderCRC(true),
)
if err != nil {
_ = pipeWriter.CloseWithError(fmt.Errorf("create zstd encoder: %w", err))
return
}
if _, err := io.Copy(encoder, reader); err != nil {
_ = encoder.Close()
_ = pipeWriter.CloseWithError(fmt.Errorf("compress: %w", err))
return
}
if err := encoder.Close(); err != nil {
_ = pipeWriter.CloseWithError(fmt.Errorf("close encoder: %w", err))
return
}
_ = pipeWriter.Close()
}
func (backuper *FullBackuper) reportError(ctx context.Context, errMsg string) {
if err := backuper.apiClient.ReportBackupError(ctx, errMsg); err != nil {
backuper.log.Error("Failed to report error to server", "error", err)
}
}
func (backuper *FullBackuper) defaultCmdBuilder(ctx context.Context) *exec.Cmd {
switch backuper.cfg.PgType {
case "docker":
return backuper.buildDockerCmd(ctx)
default:
return backuper.buildHostCmd(ctx)
}
}
func (backuper *FullBackuper) buildHostCmd(ctx context.Context) *exec.Cmd {
binary := "pg_basebackup"
if backuper.cfg.PgHostBinDir != "" {
binary = filepath.Join(backuper.cfg.PgHostBinDir, "pg_basebackup")
}
cmd := exec.CommandContext(ctx, binary,
"-Ft", "-D", "-", "-X", "none", "--verbose",
"-h", backuper.cfg.PgHost,
"-p", fmt.Sprintf("%d", backuper.cfg.PgPort),
"-U", backuper.cfg.PgUser,
)
cmd.Env = append(os.Environ(), "PGPASSWORD="+backuper.cfg.PgPassword)
return cmd
}
func (backuper *FullBackuper) buildDockerCmd(ctx context.Context) *exec.Cmd {
cmd := exec.CommandContext(ctx, "docker", "exec",
"-e", "PGPASSWORD="+backuper.cfg.PgPassword,
"-i", backuper.cfg.PgDockerContainerName,
"pg_basebackup",
"-Ft", "-D", "-", "-X", "none", "--verbose",
"-h", backuper.cfg.PgHost,
"-p", fmt.Sprintf("%d", backuper.cfg.PgPort),
"-U", backuper.cfg.PgUser,
)
return cmd
}

View File

@@ -0,0 +1,671 @@
package full_backup
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/klauspost/compress/zstd"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"databasus-agent/internal/config"
"databasus-agent/internal/features/api"
"databasus-agent/internal/logger"
)
const (
testChainValidPath = "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup"
testNextBackupTimePath = "/api/v1/backups/postgres/wal/next-full-backup-time"
testFullStartPath = "/api/v1/backups/postgres/wal/upload/full-start"
testFullCompletePath = "/api/v1/backups/postgres/wal/upload/full-complete"
testReportErrorPath = "/api/v1/backups/postgres/wal/error"
testBackupID = "test-backup-id-1234"
)
func Test_RunFullBackup_WhenChainBroken_BasebackupTriggered(t *testing.T) {
var mu sync.Mutex
var uploadReceived bool
var uploadHeaders http.Header
var finalizeReceived bool
var finalizeBody map[string]any
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{
IsValid: false,
Error: "wal_chain_broken",
LastContiguousSegment: "000000010000000100000011",
})
case testFullStartPath:
mu.Lock()
uploadReceived = true
uploadHeaders = r.Header.Clone()
mu.Unlock()
_, _ = io.ReadAll(r.Body)
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
mu.Lock()
finalizeReceived = true
_ = json.NewDecoder(r.Body).Decode(&finalizeBody)
mu.Unlock()
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "test-backup-data", validStderr())
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
go fb.Run(ctx)
waitForCondition(t, func() bool {
mu.Lock()
defer mu.Unlock()
return finalizeReceived
}, 5*time.Second)
cancel()
mu.Lock()
defer mu.Unlock()
assert.True(t, uploadReceived)
assert.Equal(t, "application/octet-stream", uploadHeaders.Get("Content-Type"))
assert.Equal(t, "test-token", uploadHeaders.Get("Authorization"))
assert.True(t, finalizeReceived)
assert.Equal(t, testBackupID, finalizeBody["backupId"])
assert.Equal(t, "000000010000000000000002", finalizeBody["startSegment"])
assert.Equal(t, "000000010000000000000002", finalizeBody["stopSegment"])
}
func Test_RunFullBackup_WhenScheduledBackupDue_BasebackupTriggered(t *testing.T) {
var mu sync.Mutex
var finalizeReceived bool
pastTime := time.Now().UTC().Add(-1 * time.Hour)
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{IsValid: true})
case testNextBackupTimePath:
writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: &pastTime})
case testFullStartPath:
_, _ = io.ReadAll(r.Body)
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
mu.Lock()
finalizeReceived = true
mu.Unlock()
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "scheduled-backup-data", validStderr())
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
go fb.Run(ctx)
waitForCondition(t, func() bool {
mu.Lock()
defer mu.Unlock()
return finalizeReceived
}, 5*time.Second)
cancel()
mu.Lock()
defer mu.Unlock()
assert.True(t, finalizeReceived)
}
func Test_RunFullBackup_WhenNoFullBackupExists_ImmediateBasebackupTriggered(t *testing.T) {
var mu sync.Mutex
var finalizeReceived bool
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{
IsValid: false,
Error: "no_full_backup",
})
case testFullStartPath:
_, _ = io.ReadAll(r.Body)
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
mu.Lock()
finalizeReceived = true
mu.Unlock()
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "first-backup-data", validStderr())
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
go fb.Run(ctx)
waitForCondition(t, func() bool {
mu.Lock()
defer mu.Unlock()
return finalizeReceived
}, 5*time.Second)
cancel()
mu.Lock()
defer mu.Unlock()
assert.True(t, finalizeReceived)
}
func Test_RunFullBackup_WhenUploadFails_RetriesAfterDelay(t *testing.T) {
var mu sync.Mutex
var uploadAttempts int
var errorReported bool
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{
IsValid: false,
Error: "no_full_backup",
})
case testFullStartPath:
_, _ = io.ReadAll(r.Body)
mu.Lock()
uploadAttempts++
attempt := uploadAttempts
mu.Unlock()
if attempt == 1 {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(`{"error":"storage unavailable"}`))
return
}
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
w.WriteHeader(http.StatusOK)
case testReportErrorPath:
mu.Lock()
errorReported = true
mu.Unlock()
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "retry-backup-data", validStderr())
origRetryDelay := retryDelay
setRetryDelay(100 * time.Millisecond)
defer setRetryDelay(origRetryDelay)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
go fb.Run(ctx)
waitForCondition(t, func() bool {
mu.Lock()
defer mu.Unlock()
return uploadAttempts >= 2
}, 10*time.Second)
cancel()
mu.Lock()
defer mu.Unlock()
assert.GreaterOrEqual(t, uploadAttempts, 2)
assert.True(t, errorReported)
}
func Test_RunFullBackup_WhenAlreadyRunning_SkipsExecution(t *testing.T) {
var mu sync.Mutex
var uploadCount int
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{
IsValid: false,
Error: "no_full_backup",
})
case testFullStartPath:
_, _ = io.ReadAll(r.Body)
mu.Lock()
uploadCount++
mu.Unlock()
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
fb.isRunning.Store(true)
fb.checkAndRunIfNeeded(context.Background())
mu.Lock()
count := uploadCount
mu.Unlock()
assert.Equal(t, 0, count, "should not trigger backup when already running")
}
func Test_RunFullBackup_WhenContextCancelled_StopsCleanly(t *testing.T) {
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{
IsValid: false,
Error: "no_full_backup",
})
case testFullStartPath:
_, _ = io.ReadAll(r.Body)
w.WriteHeader(http.StatusInternalServerError)
case testFullCompletePath:
w.WriteHeader(http.StatusOK)
case testReportErrorPath:
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
origRetryDelay := retryDelay
setRetryDelay(5 * time.Second)
defer setRetryDelay(origRetryDelay)
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
done := make(chan struct{})
go func() {
fb.Run(ctx)
close(done)
}()
select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatal("Run should have stopped after context cancellation")
}
}
func Test_RunFullBackup_WhenChainValidAndNotScheduled_NoBasebackupTriggered(t *testing.T) {
var uploadReceived atomic.Bool
futureTime := time.Now().UTC().Add(24 * time.Hour)
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{IsValid: true})
case testNextBackupTimePath:
writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: &futureTime})
case testFullStartPath:
uploadReceived.Store(true)
_, _ = io.ReadAll(r.Body)
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
go fb.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
assert.False(t, uploadReceived.Load(), "should not trigger backup when chain valid and not scheduled")
}
func Test_RunFullBackup_WhenStderrParsingFails_FinalizesWithErrorAndRetries(t *testing.T) {
var mu sync.Mutex
var errorReported bool
var finalizeWithErrorReceived bool
var finalizeBody map[string]any
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{
IsValid: false,
Error: "no_full_backup",
})
case testFullStartPath:
_, _ = io.ReadAll(r.Body)
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
mu.Lock()
finalizeWithErrorReceived = true
_ = json.NewDecoder(r.Body).Decode(&finalizeBody)
mu.Unlock()
w.WriteHeader(http.StatusOK)
case testReportErrorPath:
mu.Lock()
errorReported = true
mu.Unlock()
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "data", "pg_basebackup: unexpected output with no LSN info")
origRetryDelay := retryDelay
setRetryDelay(100 * time.Millisecond)
defer setRetryDelay(origRetryDelay)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
go fb.Run(ctx)
waitForCondition(t, func() bool {
mu.Lock()
defer mu.Unlock()
return errorReported
}, 2*time.Second)
cancel()
mu.Lock()
defer mu.Unlock()
assert.True(t, errorReported)
assert.True(t, finalizeWithErrorReceived, "should finalize with error when stderr parsing fails")
assert.Equal(t, testBackupID, finalizeBody["backupId"])
assert.NotNil(t, finalizeBody["error"], "finalize should include error message")
}
func Test_RunFullBackup_WhenNextBackupTimeNull_BasebackupTriggered(t *testing.T) {
var mu sync.Mutex
var finalizeReceived bool
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{IsValid: true})
case testNextBackupTimePath:
writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: nil})
case testFullStartPath:
_, _ = io.ReadAll(r.Body)
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
mu.Lock()
finalizeReceived = true
mu.Unlock()
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "first-run-data", validStderr())
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
go fb.Run(ctx)
waitForCondition(t, func() bool {
mu.Lock()
defer mu.Unlock()
return finalizeReceived
}, 5*time.Second)
cancel()
mu.Lock()
defer mu.Unlock()
assert.True(t, finalizeReceived)
}
func Test_RunFullBackup_WhenChainValidityReturns401_NoBasebackupTriggered(t *testing.T) {
var uploadReceived atomic.Bool
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"invalid token"}`))
case testFullStartPath:
uploadReceived.Store(true)
_, _ = io.ReadAll(r.Body)
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
go fb.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
assert.False(t, uploadReceived.Load(), "should not trigger backup when API returns 401")
}
func Test_RunFullBackup_WhenUploadSucceeds_BodyIsZstdCompressed(t *testing.T) {
var mu sync.Mutex
var receivedBody []byte
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case testChainValidPath:
writeJSON(w, api.WalChainValidityResponse{
IsValid: false,
Error: "no_full_backup",
})
case testFullStartPath:
body, _ := io.ReadAll(r.Body)
mu.Lock()
receivedBody = body
mu.Unlock()
writeJSON(w, map[string]string{"backupId": testBackupID})
case testFullCompletePath:
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
originalContent := "test-backup-content-for-compression-check"
fb := newTestFullBackuper(server.URL)
fb.cmdBuilder = mockCmdBuilder(t, originalContent, validStderr())
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
go fb.Run(ctx)
waitForCondition(t, func() bool {
mu.Lock()
defer mu.Unlock()
return len(receivedBody) > 0
}, 5*time.Second)
cancel()
mu.Lock()
body := receivedBody
mu.Unlock()
decoder, err := zstd.NewReader(nil)
require.NoError(t, err)
defer decoder.Close()
decompressed, err := decoder.DecodeAll(body, nil)
require.NoError(t, err)
assert.Equal(t, originalContent, string(decompressed))
}
func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
t.Helper()
server := httptest.NewServer(handler)
t.Cleanup(server.Close)
return server
}
func newTestFullBackuper(serverURL string) *FullBackuper {
cfg := &config.Config{
DatabasusHost: serverURL,
DbID: "test-db-id",
Token: "test-token",
PgHost: "localhost",
PgPort: 5432,
PgUser: "postgres",
PgPassword: "password",
PgType: "host",
}
apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger())
return NewFullBackuper(cfg, apiClient, logger.GetLogger())
}
func mockCmdBuilder(t *testing.T, stdoutContent, stderrContent string) CmdBuilder {
t.Helper()
return func(ctx context.Context) *exec.Cmd {
cmd := exec.CommandContext(ctx, os.Args[0],
"-test.run=TestHelperProcess",
"--",
stdoutContent,
stderrContent,
)
cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS=1")
return cmd
}
}
func TestHelperProcess(t *testing.T) {
if os.Getenv("GO_TEST_HELPER_PROCESS") != "1" {
return
}
args := os.Args
for i, arg := range args {
if arg == "--" {
args = args[i+1:]
break
}
}
if len(args) >= 1 {
_, _ = fmt.Fprint(os.Stdout, args[0])
}
if len(args) >= 2 {
_, _ = fmt.Fprint(os.Stderr, args[1])
}
os.Exit(0)
}
func validStderr() string {
return `pg_basebackup: initiating base backup, waiting for checkpoint to complete
pg_basebackup: checkpoint completed
pg_basebackup: write-ahead log start point: 0/2000028, on timeline 1
pg_basebackup: checkpoint redo point at 0/2000028
pg_basebackup: write-ahead log end point: 0/2000100
pg_basebackup: base backup completed`
}
func writeJSON(w http.ResponseWriter, v any) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(v); err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
}
func waitForCondition(t *testing.T, condition func() bool, timeout time.Duration) {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if condition() {
return
}
time.Sleep(50 * time.Millisecond)
}
t.Fatalf("condition not met within %v", timeout)
}
func setRetryDelay(d time.Duration) {
retryDelayOverride = &d
}
func init() {
retryDelayOverride = nil
}

View File

@@ -0,0 +1,75 @@
package full_backup
import (
"fmt"
"regexp"
"strconv"
"strings"
)
const defaultWalSegmentSize uint32 = 16 * 1024 * 1024 // 16 MB
var (
startLSNRegex = regexp.MustCompile(`checkpoint redo point at ([0-9A-Fa-f]+/[0-9A-Fa-f]+)`)
stopLSNRegex = regexp.MustCompile(`write-ahead log end point: ([0-9A-Fa-f]+/[0-9A-Fa-f]+)`)
)
func ParseBasebackupStderr(stderr string) (startSegment, stopSegment string, err error) {
startMatch := startLSNRegex.FindStringSubmatch(stderr)
if len(startMatch) < 2 {
return "", "", fmt.Errorf("failed to parse start WAL location from pg_basebackup stderr")
}
stopMatch := stopLSNRegex.FindStringSubmatch(stderr)
if len(stopMatch) < 2 {
return "", "", fmt.Errorf("failed to parse stop WAL location from pg_basebackup stderr")
}
startSegment, err = LSNToSegmentName(startMatch[1], 1, defaultWalSegmentSize)
if err != nil {
return "", "", fmt.Errorf("failed to convert start LSN to segment name: %w", err)
}
stopSegment, err = LSNToSegmentName(stopMatch[1], 1, defaultWalSegmentSize)
if err != nil {
return "", "", fmt.Errorf("failed to convert stop LSN to segment name: %w", err)
}
return startSegment, stopSegment, nil
}
func LSNToSegmentName(lsn string, timelineID, walSegmentSize uint32) (string, error) {
high, low, err := parseLSN(lsn)
if err != nil {
return "", err
}
segmentsPerXLogID := uint32(0x100000000 / uint64(walSegmentSize))
logID := high
segmentOffset := low / walSegmentSize
if segmentOffset >= segmentsPerXLogID {
return "", fmt.Errorf("segment offset %d exceeds segments per XLogId %d", segmentOffset, segmentsPerXLogID)
}
return fmt.Sprintf("%08X%08X%08X", timelineID, logID, segmentOffset), nil
}
func parseLSN(lsn string) (high, low uint32, err error) {
parts := strings.SplitN(lsn, "/", 2)
if len(parts) != 2 {
return 0, 0, fmt.Errorf("invalid LSN format: %q (expected X/Y)", lsn)
}
highVal, err := strconv.ParseUint(parts[0], 16, 32)
if err != nil {
return 0, 0, fmt.Errorf("invalid LSN high part %q: %w", parts[0], err)
}
lowVal, err := strconv.ParseUint(parts[1], 16, 32)
if err != nil {
return 0, 0, fmt.Errorf("invalid LSN low part %q: %w", parts[1], err)
}
return uint32(highVal), uint32(lowVal), nil
}

View File

@@ -0,0 +1,162 @@
package full_backup
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_ParseBasebackupStderr_WithPG17Output_ExtractsCorrectSegments(t *testing.T) {
stderr := `pg_basebackup: initiating base backup, waiting for checkpoint to complete
pg_basebackup: checkpoint completed
pg_basebackup: write-ahead log start point: 0/2000028, on timeline 1
pg_basebackup: starting background WAL receiver
pg_basebackup: checkpoint redo point at 0/2000028
pg_basebackup: write-ahead log end point: 0/2000100
pg_basebackup: waiting for background process to finish streaming ...
pg_basebackup: syncing data to disk ...
pg_basebackup: renaming backup_manifest.tmp to backup_manifest
pg_basebackup: base backup completed`
startSeg, stopSeg, err := ParseBasebackupStderr(stderr)
require.NoError(t, err)
assert.Equal(t, "000000010000000000000002", startSeg)
assert.Equal(t, "000000010000000000000002", stopSeg)
}
func Test_ParseBasebackupStderr_WithPG15Output_ExtractsCorrectSegments(t *testing.T) {
stderr := `pg_basebackup: initiating base backup, waiting for checkpoint to complete
pg_basebackup: checkpoint completed
pg_basebackup: write-ahead log start point: 1/AB000028, on timeline 1
pg_basebackup: checkpoint redo point at 1/AB000028
pg_basebackup: write-ahead log end point: 1/AC000000
pg_basebackup: base backup completed`
startSeg, stopSeg, err := ParseBasebackupStderr(stderr)
require.NoError(t, err)
assert.Equal(t, "0000000100000001000000AB", startSeg)
assert.Equal(t, "0000000100000001000000AC", stopSeg)
}
func Test_ParseBasebackupStderr_WithHighLogID_ExtractsCorrectSegments(t *testing.T) {
stderr := `pg_basebackup: checkpoint redo point at A/FF000028
pg_basebackup: write-ahead log end point: B/1000000`
startSeg, stopSeg, err := ParseBasebackupStderr(stderr)
require.NoError(t, err)
assert.Equal(t, "000000010000000A000000FF", startSeg)
assert.Equal(t, "000000010000000B00000001", stopSeg)
}
func Test_ParseBasebackupStderr_WhenStartLSNMissing_ReturnsError(t *testing.T) {
stderr := `pg_basebackup: write-ahead log end point: 0/2000100
pg_basebackup: base backup completed`
_, _, err := ParseBasebackupStderr(stderr)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse start WAL location")
}
func Test_ParseBasebackupStderr_WhenStopLSNMissing_ReturnsError(t *testing.T) {
stderr := `pg_basebackup: checkpoint redo point at 0/2000028
pg_basebackup: base backup completed`
_, _, err := ParseBasebackupStderr(stderr)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse stop WAL location")
}
func Test_ParseBasebackupStderr_WhenEmptyStderr_ReturnsError(t *testing.T) {
_, _, err := ParseBasebackupStderr("")
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse start WAL location")
}
func Test_LSNToSegmentName_WithBoundaryValues_ConvertsCorrectly(t *testing.T) {
tests := []struct {
name string
lsn string
timeline uint32
segSize uint32
expected string
}{
{
name: "first segment",
lsn: "0/1000000",
timeline: 1,
segSize: 16 * 1024 * 1024,
expected: "000000010000000000000001",
},
{
name: "segment at boundary FF",
lsn: "0/FF000000",
timeline: 1,
segSize: 16 * 1024 * 1024,
expected: "0000000100000000000000FF",
},
{
name: "segment in second log file",
lsn: "1/0",
timeline: 1,
segSize: 16 * 1024 * 1024,
expected: "000000010000000100000000",
},
{
name: "segment with offset within 16MB",
lsn: "0/200ABCD",
timeline: 1,
segSize: 16 * 1024 * 1024,
expected: "000000010000000000000002",
},
{
name: "zero LSN",
lsn: "0/0",
timeline: 1,
segSize: 16 * 1024 * 1024,
expected: "000000010000000000000000",
},
{
name: "high timeline ID",
lsn: "0/1000000",
timeline: 2,
segSize: 16 * 1024 * 1024,
expected: "000000020000000000000001",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := LSNToSegmentName(tt.lsn, tt.timeline, tt.segSize)
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
func Test_LSNToSegmentName_WithInvalidLSN_ReturnsError(t *testing.T) {
tests := []struct {
name string
lsn string
}{
{name: "no slash", lsn: "012345"},
{name: "empty string", lsn: ""},
{name: "invalid hex high", lsn: "GG/0"},
{name: "invalid hex low", lsn: "0/ZZ"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := LSNToSegmentName(tt.lsn, 1, 16*1024*1024)
require.Error(t, err)
})
}
}

View File

@@ -0,0 +1,120 @@
//go:build !windows
package start
import (
"errors"
"fmt"
"log/slog"
"os"
"os/exec"
"syscall"
"time"
)
const (
logFileName = "databasus.log"
stopTimeout = 30 * time.Second
stopPollInterval = 500 * time.Millisecond
daemonStartupDelay = 500 * time.Millisecond
)
func Stop(log *slog.Logger) error {
pid, err := ReadLockFilePID()
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return errors.New("agent is not running (no lock file found)")
}
return fmt.Errorf("failed to read lock file: %w", err)
}
if !isProcessAlive(pid) {
_ = os.Remove(lockFileName)
return fmt.Errorf("agent is not running (stale lock file removed, PID %d)", pid)
}
log.Info("Sending SIGTERM to agent", "pid", pid)
if err := syscall.Kill(pid, syscall.SIGTERM); err != nil {
return fmt.Errorf("failed to send SIGTERM to PID %d: %w", pid, err)
}
deadline := time.Now().Add(stopTimeout)
for time.Now().Before(deadline) {
if !isProcessAlive(pid) {
log.Info("Agent stopped", "pid", pid)
return nil
}
time.Sleep(stopPollInterval)
}
return fmt.Errorf("agent (PID %d) did not stop within %s — process may be stuck", pid, stopTimeout)
}
func Status(log *slog.Logger) error {
pid, err := ReadLockFilePID()
if err != nil {
if errors.Is(err, os.ErrNotExist) {
fmt.Println("Agent is not running")
return nil
}
return fmt.Errorf("failed to read lock file: %w", err)
}
if isProcessAlive(pid) {
fmt.Printf("Agent is running (PID %d)\n", pid)
} else {
fmt.Println("Agent is not running (stale lock file)")
_ = os.Remove(lockFileName)
}
return nil
}
func spawnDaemon(log *slog.Logger) (int, error) {
execPath, err := os.Executable()
if err != nil {
return 0, fmt.Errorf("failed to resolve executable path: %w", err)
}
args := []string{"_run"}
logFile, err := os.OpenFile(logFileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return 0, fmt.Errorf("failed to open log file %s: %w", logFileName, err)
}
cwd, err := os.Getwd()
if err != nil {
_ = logFile.Close()
return 0, fmt.Errorf("failed to get working directory: %w", err)
}
cmd := exec.Command(execPath, args...)
cmd.Dir = cwd
cmd.Stderr = logFile
cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}
if err := cmd.Start(); err != nil {
_ = logFile.Close()
return 0, fmt.Errorf("failed to start daemon process: %w", err)
}
pid := cmd.Process.Pid
// Detach — we don't wait for the child
_ = logFile.Close()
time.Sleep(daemonStartupDelay)
if !isProcessAlive(pid) {
return 0, fmt.Errorf("daemon process (PID %d) exited immediately — check %s for details", pid, logFileName)
}
log.Info("Daemon spawned", "pid", pid, "log", logFileName)
return pid, nil
}

View File

@@ -0,0 +1,20 @@
//go:build windows
package start
import (
"errors"
"log/slog"
)
func Stop(log *slog.Logger) error {
return errors.New("stop is not supported on Windows — use Ctrl+C in the terminal where the agent is running")
}
func Status(log *slog.Logger) error {
return errors.New("status is not supported on Windows — check the terminal where the agent is running")
}
func spawnDaemon(_ *slog.Logger) (int, error) {
return 0, errors.New("daemon mode is not supported on Windows")
}

View File

@@ -0,0 +1,117 @@
//go:build !windows
package start
import (
"errors"
"fmt"
"io"
"log/slog"
"os"
"strconv"
"strings"
"syscall"
)
const lockFileName = "databasus.lock"
func AcquireLock(log *slog.Logger) (*os.File, error) {
f, err := os.OpenFile(lockFileName, os.O_CREATE|os.O_RDWR, 0o644)
if err != nil {
return nil, fmt.Errorf("failed to open lock file: %w", err)
}
err = syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
if err == nil {
if err := writePID(f); err != nil {
_ = f.Close()
return nil, err
}
log.Info("Process lock acquired", "pid", os.Getpid(), "lockFile", lockFileName)
return f, nil
}
if !errors.Is(err, syscall.EWOULDBLOCK) {
_ = f.Close()
return nil, fmt.Errorf("failed to acquire lock: %w", err)
}
pid, pidErr := readLockPID(f)
_ = f.Close()
if pidErr != nil {
return nil, fmt.Errorf("Another instance is already running.")
}
return nil, fmt.Errorf("Another instance is already running (PID %d).", pid)
}
func ReleaseLock(f *os.File) {
_ = syscall.Flock(int(f.Fd()), syscall.LOCK_UN)
_ = f.Close()
_ = os.Remove(lockFileName)
}
func ReadLockFilePID() (int, error) {
f, err := os.Open(lockFileName)
if err != nil {
return 0, err
}
defer f.Close()
return readLockPID(f)
}
func writePID(f *os.File) error {
if err := f.Truncate(0); err != nil {
return fmt.Errorf("failed to truncate lock file: %w", err)
}
if _, err := f.Seek(0, io.SeekStart); err != nil {
return fmt.Errorf("failed to seek lock file: %w", err)
}
if _, err := fmt.Fprintf(f, "%d\n", os.Getpid()); err != nil {
return fmt.Errorf("failed to write PID to lock file: %w", err)
}
return f.Sync()
}
func readLockPID(f *os.File) (int, error) {
if _, err := f.Seek(0, io.SeekStart); err != nil {
return 0, err
}
data, err := io.ReadAll(f)
if err != nil {
return 0, err
}
s := strings.TrimSpace(string(data))
if s == "" {
return 0, errors.New("lock file is empty")
}
pid, err := strconv.Atoi(s)
if err != nil {
return 0, fmt.Errorf("invalid PID in lock file: %w", err)
}
return pid, nil
}
func isProcessAlive(pid int) bool {
err := syscall.Kill(pid, 0)
if err == nil {
return true
}
if errors.Is(err, syscall.EPERM) {
return true
}
return false
}

View File

@@ -0,0 +1,148 @@
//go:build !windows
package start
import (
"fmt"
"os"
"strconv"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"databasus-agent/internal/logger"
)
func Test_AcquireLock_LockFileCreatedWithPID(t *testing.T) {
setupTempDir(t)
log := logger.GetLogger()
lockFile, err := AcquireLock(log)
require.NoError(t, err)
defer ReleaseLock(lockFile)
data, err := os.ReadFile(lockFileName)
require.NoError(t, err)
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
require.NoError(t, err)
assert.Equal(t, os.Getpid(), pid)
}
func Test_AcquireLock_SecondAcquireFails_WhenFirstHeld(t *testing.T) {
setupTempDir(t)
log := logger.GetLogger()
first, err := AcquireLock(log)
require.NoError(t, err)
defer ReleaseLock(first)
second, err := AcquireLock(log)
assert.Nil(t, second)
require.Error(t, err)
assert.Contains(t, err.Error(), "Another instance is already running")
assert.Contains(t, err.Error(), fmt.Sprintf("PID %d", os.Getpid()))
}
func Test_AcquireLock_StaleLockReacquired_WhenProcessDead(t *testing.T) {
setupTempDir(t)
log := logger.GetLogger()
err := os.WriteFile(lockFileName, []byte("999999999\n"), 0o644)
require.NoError(t, err)
lockFile, err := AcquireLock(log)
require.NoError(t, err)
defer ReleaseLock(lockFile)
data, err := os.ReadFile(lockFileName)
require.NoError(t, err)
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
require.NoError(t, err)
assert.Equal(t, os.Getpid(), pid)
}
func Test_ReleaseLock_LockFileRemoved(t *testing.T) {
setupTempDir(t)
log := logger.GetLogger()
lockFile, err := AcquireLock(log)
require.NoError(t, err)
ReleaseLock(lockFile)
_, err = os.Stat(lockFileName)
assert.True(t, os.IsNotExist(err))
}
func Test_AcquireLock_ReacquiredAfterRelease(t *testing.T) {
setupTempDir(t)
log := logger.GetLogger()
first, err := AcquireLock(log)
require.NoError(t, err)
ReleaseLock(first)
second, err := AcquireLock(log)
require.NoError(t, err)
defer ReleaseLock(second)
data, err := os.ReadFile(lockFileName)
require.NoError(t, err)
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
require.NoError(t, err)
assert.Equal(t, os.Getpid(), pid)
}
func Test_isProcessAlive_ReturnsTrueForSelf(t *testing.T) {
assert.True(t, isProcessAlive(os.Getpid()))
}
func Test_isProcessAlive_ReturnsFalseForNonExistentPID(t *testing.T) {
assert.False(t, isProcessAlive(999999999))
}
func Test_readLockPID_ParsesValidPID(t *testing.T) {
setupTempDir(t)
f, err := os.CreateTemp("", "lock-test-*")
require.NoError(t, err)
defer os.Remove(f.Name())
_, err = f.WriteString("12345\n")
require.NoError(t, err)
pid, err := readLockPID(f)
require.NoError(t, err)
assert.Equal(t, 12345, pid)
}
func Test_readLockPID_ReturnsErrorForEmptyFile(t *testing.T) {
setupTempDir(t)
f, err := os.CreateTemp("", "lock-test-*")
require.NoError(t, err)
defer os.Remove(f.Name())
_, err = readLockPID(f)
require.Error(t, err)
assert.Contains(t, err.Error(), "lock file is empty")
}
func setupTempDir(t *testing.T) string {
t.Helper()
origDir, err := os.Getwd()
require.NoError(t, err)
dir := t.TempDir()
require.NoError(t, os.Chdir(dir))
t.Cleanup(func() { _ = os.Chdir(origDir) })
return dir
}

View File

@@ -0,0 +1,90 @@
//go:build !windows
package start
import (
"context"
"log/slog"
"os"
"syscall"
"time"
)
const lockWatchInterval = 5 * time.Second
type LockWatcher struct {
originalInode uint64
cancel context.CancelFunc
log *slog.Logger
}
func NewLockWatcher(lockFile *os.File, cancel context.CancelFunc, log *slog.Logger) (*LockWatcher, error) {
inode, err := getFileInode(lockFile)
if err != nil {
return nil, err
}
return &LockWatcher{
originalInode: inode,
cancel: cancel,
log: log,
}, nil
}
func (w *LockWatcher) Run(ctx context.Context) {
ticker := time.NewTicker(lockWatchInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
w.check()
}
}
}
func (w *LockWatcher) check() {
info, err := os.Stat(lockFileName)
if err != nil {
w.log.Error("Lock file disappeared, shutting down", "file", lockFileName, "error", err)
w.cancel()
return
}
currentInode, err := getStatInode(info)
if err != nil {
w.log.Error("Failed to read lock file inode, shutting down", "error", err)
w.cancel()
return
}
if currentInode != w.originalInode {
w.log.Error("Lock file was replaced (inode changed), shutting down",
"originalInode", w.originalInode,
"currentInode", currentInode,
)
w.cancel()
}
}
func getFileInode(f *os.File) (uint64, error) {
info, err := f.Stat()
if err != nil {
return 0, err
}
return getStatInode(info)
}
func getStatInode(info os.FileInfo) (uint64, error) {
stat, ok := info.Sys().(*syscall.Stat_t)
if !ok {
return 0, os.ErrInvalid
}
return stat.Ino, nil
}

View File

@@ -0,0 +1,110 @@
//go:build !windows
package start
import (
"context"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"databasus-agent/internal/logger"
)
func Test_NewLockWatcher_CapturesInode(t *testing.T) {
setupTempDir(t)
log := logger.GetLogger()
lockFile, err := AcquireLock(log)
require.NoError(t, err)
defer ReleaseLock(lockFile)
_, cancel := context.WithCancel(context.Background())
defer cancel()
watcher, err := NewLockWatcher(lockFile, cancel, log)
require.NoError(t, err)
assert.NotZero(t, watcher.originalInode)
}
func Test_LockWatcher_FileUnchanged_ContextNotCancelled(t *testing.T) {
setupTempDir(t)
log := logger.GetLogger()
lockFile, err := AcquireLock(log)
require.NoError(t, err)
defer ReleaseLock(lockFile)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
watcher, err := NewLockWatcher(lockFile, cancel, log)
require.NoError(t, err)
watcher.check()
watcher.check()
watcher.check()
select {
case <-ctx.Done():
t.Fatal("context should not be cancelled when lock file is unchanged")
default:
}
}
func Test_LockWatcher_FileDeleted_CancelsContext(t *testing.T) {
setupTempDir(t)
log := logger.GetLogger()
lockFile, err := AcquireLock(log)
require.NoError(t, err)
defer ReleaseLock(lockFile)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
watcher, err := NewLockWatcher(lockFile, cancel, log)
require.NoError(t, err)
err = os.Remove(lockFileName)
require.NoError(t, err)
watcher.check()
select {
case <-ctx.Done():
default:
t.Fatal("context should be cancelled when lock file is deleted")
}
}
func Test_LockWatcher_FileReplacedWithDifferentInode_CancelsContext(t *testing.T) {
setupTempDir(t)
log := logger.GetLogger()
lockFile, err := AcquireLock(log)
require.NoError(t, err)
defer ReleaseLock(lockFile)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
watcher, err := NewLockWatcher(lockFile, cancel, log)
require.NoError(t, err)
err = os.Remove(lockFileName)
require.NoError(t, err)
err = os.WriteFile(lockFileName, []byte("99999\n"), 0o644)
require.NoError(t, err)
watcher.check()
select {
case <-ctx.Done():
default:
t.Fatal("context should be cancelled when lock file inode changes")
}
}

View File

@@ -0,0 +1,17 @@
//go:build windows
package start
import (
"context"
"log/slog"
"os"
)
type LockWatcher struct{}
func NewLockWatcher(_ *os.File, _ context.CancelFunc, _ *slog.Logger) (*LockWatcher, error) {
return &LockWatcher{}, nil
}
func (w *LockWatcher) Run(_ context.Context) {}

View File

@@ -0,0 +1,18 @@
package start
import (
"log/slog"
"os"
)
func AcquireLock(log *slog.Logger) (*os.File, error) {
log.Warn("Process locking is not supported on Windows, skipping")
return nil, nil
}
func ReleaseLock(f *os.File) {
if f != nil {
_ = f.Close()
}
}

View File

@@ -5,22 +5,32 @@ import (
"errors"
"fmt"
"log/slog"
"os"
"os/exec"
"os/signal"
"path/filepath"
"runtime"
"strconv"
"strings"
"syscall"
"time"
"github.com/jackc/pgx/v5"
"databasus-agent/internal/config"
"databasus-agent/internal/features/api"
full_backup "databasus-agent/internal/features/full_backup"
"databasus-agent/internal/features/upgrade"
"databasus-agent/internal/features/wal"
)
const (
pgBasebackupVerifyTimeout = 10 * time.Second
dbVerifyTimeout = 10 * time.Second
minPgMajorVersion = 15
)
func Run(cfg *config.Config, log *slog.Logger) error {
func Start(cfg *config.Config, agentVersion string, isDev bool, log *slog.Logger) error {
if err := validateConfig(cfg); err != nil {
return err
}
@@ -33,10 +43,59 @@ func Run(cfg *config.Config, log *slog.Logger) error {
return err
}
log.Info("start: stub — not yet implemented",
"dbId", cfg.DbID,
"hasToken", cfg.Token != "",
)
if runtime.GOOS == "windows" {
return RunDaemon(cfg, agentVersion, isDev, log)
}
pid, err := spawnDaemon(log)
if err != nil {
return err
}
fmt.Printf("Agent started in background (PID %d)\n", pid)
return nil
}
func RunDaemon(cfg *config.Config, agentVersion string, isDev bool, log *slog.Logger) error {
lockFile, err := AcquireLock(log)
if err != nil {
return err
}
defer ReleaseLock(lockFile)
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer cancel()
watcher, err := NewLockWatcher(lockFile, cancel, log)
if err != nil {
return fmt.Errorf("failed to initialize lock watcher: %w", err)
}
go watcher.Run(ctx)
apiClient := api.NewClient(cfg.DatabasusHost, cfg.Token, log)
var backgroundUpgrader *upgrade.BackgroundUpgrader
if agentVersion != "dev" && runtime.GOOS != "windows" {
backgroundUpgrader = upgrade.NewBackgroundUpgrader(apiClient, agentVersion, isDev, cancel, log)
go backgroundUpgrader.Run(ctx)
}
fullBackuper := full_backup.NewFullBackuper(cfg, apiClient, log)
go fullBackuper.Run(ctx)
streamer := wal.NewStreamer(cfg, apiClient, log)
streamer.Run(ctx)
if backgroundUpgrader != nil {
backgroundUpgrader.WaitForCompletion(30 * time.Second)
if backgroundUpgrader.IsUpgraded() {
return upgrade.ErrUpgradeRestart
}
}
log.Info("Agent stopped")
return nil
}
@@ -70,8 +129,8 @@ func validateConfig(cfg *config.Config) error {
return fmt.Errorf("argument pg-type must be 'host' or 'docker', got '%s'", cfg.PgType)
}
if cfg.WalDir == "" {
return errors.New("argument wal-dir is required")
if cfg.PgWalDir == "" {
return errors.New("argument pg-wal-dir is required")
}
if cfg.PgType == "docker" && cfg.PgDockerContainerName == "" {
@@ -169,11 +228,44 @@ func verifyDatabase(cfg *config.Config, log *slog.Logger) error {
)
}
var versionNumStr string
if err := conn.QueryRow(ctx, "SHOW server_version_num").Scan(&versionNumStr); err != nil {
return fmt.Errorf("failed to query PostgreSQL version: %w", err)
}
majorVersion, err := parsePgVersionNum(versionNumStr)
if err != nil {
return fmt.Errorf("failed to parse PostgreSQL version '%s': %w", versionNumStr, err)
}
if majorVersion < minPgMajorVersion {
return fmt.Errorf(
"PostgreSQL %d is not supported, minimum required version is %d",
majorVersion, minPgMajorVersion,
)
}
log.Info("PostgreSQL connection verified",
"host", cfg.PgHost,
"port", cfg.PgPort,
"user", cfg.PgUser,
"version", majorVersion,
)
return nil
}
func parsePgVersionNum(versionNumStr string) (int, error) {
versionNum, err := strconv.Atoi(strings.TrimSpace(versionNumStr))
if err != nil {
return 0, fmt.Errorf("invalid version number: %w", err)
}
if versionNum <= 0 {
return 0, fmt.Errorf("invalid version number: %d", versionNum)
}
majorVersion := versionNum / 10000
return majorVersion, nil
}

View File

@@ -0,0 +1,84 @@
package start
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_ParsePgVersionNum_SupportedVersions_ReturnsMajorVersion(t *testing.T) {
tests := []struct {
name string
versionNumStr string
expectedMajor int
}{
{name: "PG 15.0", versionNumStr: "150000", expectedMajor: 15},
{name: "PG 15.4", versionNumStr: "150004", expectedMajor: 15},
{name: "PG 16.0", versionNumStr: "160000", expectedMajor: 16},
{name: "PG 16.3", versionNumStr: "160003", expectedMajor: 16},
{name: "PG 17.2", versionNumStr: "170002", expectedMajor: 17},
{name: "PG 18.0", versionNumStr: "180000", expectedMajor: 18},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
major, err := parsePgVersionNum(tt.versionNumStr)
require.NoError(t, err)
assert.Equal(t, tt.expectedMajor, major)
assert.GreaterOrEqual(t, major, minPgMajorVersion)
})
}
}
func Test_ParsePgVersionNum_UnsupportedVersions_ReturnsMajorVersionBelow15(t *testing.T) {
tests := []struct {
name string
versionNumStr string
expectedMajor int
}{
{name: "PG 12.5", versionNumStr: "120005", expectedMajor: 12},
{name: "PG 13.0", versionNumStr: "130000", expectedMajor: 13},
{name: "PG 14.12", versionNumStr: "140012", expectedMajor: 14},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
major, err := parsePgVersionNum(tt.versionNumStr)
require.NoError(t, err)
assert.Equal(t, tt.expectedMajor, major)
assert.Less(t, major, minPgMajorVersion)
})
}
}
func Test_ParsePgVersionNum_InvalidInput_ReturnsError(t *testing.T) {
tests := []struct {
name string
versionNumStr string
}{
{name: "empty string", versionNumStr: ""},
{name: "non-numeric", versionNumStr: "abc"},
{name: "negative number", versionNumStr: "-1"},
{name: "zero", versionNumStr: "0"},
{name: "float", versionNumStr: "15.4"},
{name: "whitespace only", versionNumStr: " "},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := parsePgVersionNum(tt.versionNumStr)
require.Error(t, err)
})
}
}
func Test_ParsePgVersionNum_WithWhitespace_ParsesCorrectly(t *testing.T) {
major, err := parsePgVersionNum(" 150004 ")
require.NoError(t, err)
assert.Equal(t, 15, major)
}

View File

@@ -0,0 +1,88 @@
package upgrade
import (
"context"
"log/slog"
"sync/atomic"
"time"
"databasus-agent/internal/features/api"
)
const backgroundCheckInterval = 5 * time.Second
type BackgroundUpgrader struct {
apiClient *api.Client
currentVersion string
isDev bool
cancel context.CancelFunc
isUpgraded atomic.Bool
log *slog.Logger
done chan struct{}
}
func NewBackgroundUpgrader(
apiClient *api.Client,
currentVersion string,
isDev bool,
cancel context.CancelFunc,
log *slog.Logger,
) *BackgroundUpgrader {
return &BackgroundUpgrader{
apiClient,
currentVersion,
isDev,
cancel,
atomic.Bool{},
log,
make(chan struct{}),
}
}
func (u *BackgroundUpgrader) Run(ctx context.Context) {
defer close(u.done)
ticker := time.NewTicker(backgroundCheckInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if u.checkAndUpgrade() {
return
}
}
}
}
func (u *BackgroundUpgrader) IsUpgraded() bool {
return u.isUpgraded.Load()
}
func (u *BackgroundUpgrader) WaitForCompletion(timeout time.Duration) {
select {
case <-u.done:
case <-time.After(timeout):
}
}
func (u *BackgroundUpgrader) checkAndUpgrade() bool {
isUpgraded, err := CheckAndUpdate(u.apiClient, u.currentVersion, u.isDev, u.log)
if err != nil {
u.log.Warn("Background update check failed", "error", err)
return false
}
if !isUpgraded {
return false
}
u.log.Info("Background upgrade complete, restarting...")
u.isUpgraded.Store(true)
u.cancel()
return true
}

View File

@@ -1,5 +0,0 @@
package upgrade
type versionResponse struct {
Version string `json:"version"`
}

View File

@@ -0,0 +1,5 @@
package upgrade
import "errors"
var ErrUpgradeRestart = errors.New("agent upgraded, restart required")

View File

@@ -2,44 +2,47 @@ package upgrade
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"os/exec"
"runtime"
"strings"
"syscall"
"time"
"databasus-agent/internal/features/api"
)
func CheckAndUpdate(databasusHost, currentVersion string, isDev bool, log *slog.Logger) error {
// CheckAndUpdate checks if a new version is available and upgrades the binary on disk.
// Returns (true, nil) if the binary was upgraded, (false, nil) if already up to date,
// or (false, err) on failure. Callers are responsible for re-exec or restart signaling.
func CheckAndUpdate(apiClient *api.Client, currentVersion string, isDev bool, log *slog.Logger) (bool, error) {
if isDev {
log.Info("Skipping update check (development mode)")
return nil
return false, nil
}
serverVersion, err := fetchServerVersion(databasusHost, log)
serverVersion, err := apiClient.FetchServerVersion(context.Background())
if err != nil {
return fmt.Errorf(
"unable to check version, please verify Databasus server is available at %s: %w",
databasusHost,
log.Warn("Could not reach server for update check", "error", err)
return false, fmt.Errorf(
"unable to check version, please verify Databasus server is available: %w",
err,
)
}
if serverVersion == currentVersion {
log.Info("Agent version is up to date", "version", currentVersion)
return nil
return false, nil
}
log.Info("Updating agent...", "current", currentVersion, "target", serverVersion)
selfPath, err := os.Executable()
if err != nil {
return fmt.Errorf("failed to determine executable path: %w", err)
return false, fmt.Errorf("failed to determine executable path: %w", err)
}
tempPath := selfPath + ".update"
@@ -48,93 +51,25 @@ func CheckAndUpdate(databasusHost, currentVersion string, isDev bool, log *slog.
_ = os.Remove(tempPath)
}()
if err := downloadBinary(databasusHost, tempPath); err != nil {
return fmt.Errorf("failed to download update: %w", err)
if err := apiClient.DownloadAgentBinary(context.Background(), runtime.GOARCH, tempPath); err != nil {
return false, fmt.Errorf("failed to download update: %w", err)
}
if err := os.Chmod(tempPath, 0o755); err != nil {
return fmt.Errorf("failed to set permissions on update: %w", err)
return false, fmt.Errorf("failed to set permissions on update: %w", err)
}
if err := verifyBinary(tempPath, serverVersion); err != nil {
return fmt.Errorf("update verification failed: %w", err)
return false, fmt.Errorf("update verification failed: %w", err)
}
if err := os.Rename(tempPath, selfPath); err != nil {
return fmt.Errorf("failed to replace binary (try --skip-update if this persists): %w", err)
return false, fmt.Errorf("failed to replace binary (try --skip-update if this persists): %w", err)
}
log.Info("Update complete, re-executing...")
log.Info("Agent binary updated", "version", serverVersion)
return syscall.Exec(selfPath, os.Args, os.Environ())
}
func fetchServerVersion(host string, log *slog.Logger) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
client := &http.Client{Timeout: 10 * time.Second}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, host+"/api/v1/system/version", nil)
if err != nil {
return "", err
}
resp, err := client.Do(req)
if err != nil {
log.Warn("Could not reach server for update check, continuing", "error", err)
return "", err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
log.Warn(
"Server returned non-OK status for version check, continuing",
"status",
resp.StatusCode,
)
return "", fmt.Errorf("status %d", resp.StatusCode)
}
var ver versionResponse
if err := json.NewDecoder(resp.Body).Decode(&ver); err != nil {
log.Warn("Failed to parse server version response, continuing", "error", err)
return "", err
}
return ver.Version, nil
}
func downloadBinary(host, destPath string) error {
url := fmt.Sprintf("%s/api/v1/system/agent?arch=%s", host, runtime.GOARCH)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("server returned %d for agent download", resp.StatusCode)
}
f, err := os.Create(destPath)
if err != nil {
return err
}
defer func() { _ = f.Close() }()
_, err = io.Copy(f, resp.Body)
return err
return true, nil
}
func verifyBinary(binaryPath, expectedVersion string) error {

View File

@@ -0,0 +1,182 @@
package wal
import (
"context"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"time"
"github.com/klauspost/compress/zstd"
"databasus-agent/internal/config"
"databasus-agent/internal/features/api"
)
const (
pollInterval = 2 * time.Second
uploadTimeout = 5 * time.Minute
)
var segmentNameRegex = regexp.MustCompile(`^[0-9A-Fa-f]{24}$`)
type Streamer struct {
cfg *config.Config
apiClient *api.Client
log *slog.Logger
}
func NewStreamer(cfg *config.Config, apiClient *api.Client, log *slog.Logger) *Streamer {
return &Streamer{
cfg: cfg,
apiClient: apiClient,
log: log,
}
}
func (s *Streamer) Run(ctx context.Context) {
s.log.Info("WAL streamer started", "pgWalDir", s.cfg.PgWalDir)
s.processQueue(ctx)
ticker := time.NewTicker(pollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
s.log.Info("WAL streamer stopping")
return
case <-ticker.C:
s.processQueue(ctx)
}
}
}
func (s *Streamer) processQueue(ctx context.Context) {
segments, err := s.listSegments()
if err != nil {
s.log.Error("Failed to list WAL segments", "error", err)
return
}
for _, segmentName := range segments {
if ctx.Err() != nil {
return
}
if err := s.uploadSegment(ctx, segmentName); err != nil {
s.log.Error("Failed to upload WAL segment",
"segment", segmentName,
"error", err,
)
return
}
}
}
func (s *Streamer) listSegments() ([]string, error) {
entries, err := os.ReadDir(s.cfg.PgWalDir)
if err != nil {
return nil, fmt.Errorf("read wal dir: %w", err)
}
var segments []string
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if strings.HasSuffix(name, ".tmp") {
continue
}
if !segmentNameRegex.MatchString(name) {
continue
}
segments = append(segments, name)
}
sort.Strings(segments)
return segments, nil
}
func (s *Streamer) uploadSegment(ctx context.Context, segmentName string) error {
filePath := filepath.Join(s.cfg.PgWalDir, segmentName)
pr, pw := io.Pipe()
go s.compressAndStream(pw, filePath)
uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout)
defer cancel()
result, err := s.apiClient.UploadWalSegment(uploadCtx, segmentName, pr)
if err != nil {
return err
}
if result.IsGapDetected {
s.log.Warn("WAL chain gap detected",
"segment", segmentName,
"expected", result.ExpectedSegmentName,
"received", result.ReceivedSegmentName,
)
return fmt.Errorf("gap detected for segment %s", segmentName)
}
s.log.Debug("WAL segment uploaded", "segment", segmentName)
if *s.cfg.IsDeleteWalAfterUpload {
if err := os.Remove(filePath); err != nil {
s.log.Warn("Failed to delete uploaded WAL segment",
"segment", segmentName,
"error", err,
)
}
}
return nil
}
func (s *Streamer) compressAndStream(pw *io.PipeWriter, filePath string) {
f, err := os.Open(filePath)
if err != nil {
_ = pw.CloseWithError(fmt.Errorf("open file: %w", err))
return
}
defer func() { _ = f.Close() }()
encoder, err := zstd.NewWriter(pw,
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(5)),
zstd.WithEncoderCRC(true),
)
if err != nil {
_ = pw.CloseWithError(fmt.Errorf("create zstd encoder: %w", err))
return
}
if _, err := io.Copy(encoder, f); err != nil {
_ = encoder.Close()
_ = pw.CloseWithError(fmt.Errorf("compress: %w", err))
return
}
if err := encoder.Close(); err != nil {
_ = pw.CloseWithError(fmt.Errorf("close encoder: %w", err))
return
}
_ = pw.Close()
}

View File

@@ -0,0 +1,348 @@
package wal
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/klauspost/compress/zstd"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"databasus-agent/internal/config"
"databasus-agent/internal/features/api"
"databasus-agent/internal/logger"
)
func Test_UploadSegment_SingleSegment_ServerReceivesCorrectHeadersAndBody(t *testing.T) {
walDir := createTestWalDir(t)
segmentContent := []byte("test-wal-segment-data-for-upload")
writeTestSegment(t, walDir, "000000010000000100000001", segmentContent)
var receivedHeaders http.Header
var receivedBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHeaders = r.Header.Clone()
body, err := io.ReadAll(r.Body)
require.NoError(t, err)
receivedBody = body
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
require.NotNil(t, receivedHeaders)
assert.Equal(t, "test-token", receivedHeaders.Get("Authorization"))
assert.Equal(t, "application/octet-stream", receivedHeaders.Get("Content-Type"))
assert.Equal(t, "000000010000000100000001", receivedHeaders.Get("X-Wal-Segment-Name"))
decompressed := decompressZstd(t, receivedBody)
assert.Equal(t, segmentContent, decompressed)
}
func Test_UploadSegments_MultipleSegmentsOutOfOrder_UploadedInAscendingOrder(t *testing.T) {
walDir := createTestWalDir(t)
writeTestSegment(t, walDir, "000000010000000100000003", []byte("third"))
writeTestSegment(t, walDir, "000000010000000100000001", []byte("first"))
writeTestSegment(t, walDir, "000000010000000100000002", []byte("second"))
var mu sync.Mutex
var uploadOrder []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
uploadOrder = append(uploadOrder, r.Header.Get("X-Wal-Segment-Name"))
mu.Unlock()
_, _ = io.ReadAll(r.Body)
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
mu.Lock()
defer mu.Unlock()
require.Len(t, uploadOrder, 3)
assert.Equal(t, "000000010000000100000001", uploadOrder[0])
assert.Equal(t, "000000010000000100000002", uploadOrder[1])
assert.Equal(t, "000000010000000100000003", uploadOrder[2])
}
func Test_UploadSegments_DirectoryHasTmpFiles_TmpFilesIgnored(t *testing.T) {
walDir := createTestWalDir(t)
writeTestSegment(t, walDir, "000000010000000100000001", []byte("real segment"))
writeTestSegment(t, walDir, "000000010000000100000002.tmp", []byte("partial copy"))
var mu sync.Mutex
var uploadedSegments []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
uploadedSegments = append(uploadedSegments, r.Header.Get("X-Wal-Segment-Name"))
mu.Unlock()
_, _ = io.ReadAll(r.Body)
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
mu.Lock()
defer mu.Unlock()
require.Len(t, uploadedSegments, 1)
assert.Equal(t, "000000010000000100000001", uploadedSegments[0])
}
func Test_UploadSegment_DeleteEnabled_FileRemovedAfterUpload(t *testing.T) {
walDir := createTestWalDir(t)
segmentName := "000000010000000100000001"
writeTestSegment(t, walDir, segmentName, []byte("segment data"))
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
isDeleteEnabled := true
cfg := createTestConfig(walDir, server.URL)
cfg.IsDeleteWalAfterUpload = &isDeleteEnabled
apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger())
streamer := NewStreamer(cfg, apiClient, logger.GetLogger())
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
_, err := os.Stat(filepath.Join(walDir, segmentName))
assert.True(t, os.IsNotExist(err), "segment file should be deleted after successful upload")
}
func Test_UploadSegment_DeleteDisabled_FileKeptAfterUpload(t *testing.T) {
walDir := createTestWalDir(t)
segmentName := "000000010000000100000001"
writeTestSegment(t, walDir, segmentName, []byte("segment data"))
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
isDeleteDisabled := false
cfg := createTestConfig(walDir, server.URL)
cfg.IsDeleteWalAfterUpload = &isDeleteDisabled
apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger())
streamer := NewStreamer(cfg, apiClient, logger.GetLogger())
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
_, err := os.Stat(filepath.Join(walDir, segmentName))
assert.NoError(t, err, "segment file should be kept when delete is disabled")
}
func Test_UploadSegment_ServerReturns500_FileKeptInQueue(t *testing.T) {
walDir := createTestWalDir(t)
segmentName := "000000010000000100000001"
writeTestSegment(t, walDir, segmentName, []byte("segment data"))
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(`{"error":"internal server error"}`))
}))
defer server.Close()
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
_, err := os.Stat(filepath.Join(walDir, segmentName))
assert.NoError(t, err, "segment file should remain in queue after server error")
}
func Test_ProcessQueue_EmptyDirectory_NoUploads(t *testing.T) {
walDir := createTestWalDir(t)
uploadCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
uploadCount++
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
assert.Equal(t, 0, uploadCount, "no uploads should occur for empty directory")
}
func Test_Run_ContextCancelled_StopsImmediately(t *testing.T) {
walDir := createTestWalDir(t)
streamer := newTestStreamer(walDir, "http://localhost:0")
ctx, cancel := context.WithCancel(context.Background())
cancel()
done := make(chan struct{})
go func() {
streamer.Run(ctx)
close(done)
}()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("Run should have stopped immediately when context is already cancelled")
}
}
func Test_UploadSegment_ServerReturns409_FileNotDeleted(t *testing.T) {
walDir := createTestWalDir(t)
segmentName := "000000010000000100000005"
writeTestSegment(t, walDir, segmentName, []byte("gap segment"))
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusConflict)
resp := map[string]string{
"error": "gap_detected",
"expectedSegmentName": "000000010000000100000003",
"receivedSegmentName": segmentName,
}
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
streamer := newTestStreamer(walDir, server.URL)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
go streamer.Run(ctx)
time.Sleep(500 * time.Millisecond)
cancel()
_, err := os.Stat(filepath.Join(walDir, segmentName))
assert.NoError(t, err, "segment file should not be deleted on gap detection")
}
func newTestStreamer(walDir, serverURL string) *Streamer {
cfg := createTestConfig(walDir, serverURL)
apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger())
return NewStreamer(cfg, apiClient, logger.GetLogger())
}
func createTestWalDir(t *testing.T) string {
t.Helper()
baseDir := filepath.Join(".", ".test-tmp")
if err := os.MkdirAll(baseDir, 0o755); err != nil {
t.Fatalf("failed to create base test dir: %v", err)
}
dir, err := os.MkdirTemp(baseDir, t.Name()+"-*")
if err != nil {
t.Fatalf("failed to create test wal dir: %v", err)
}
t.Cleanup(func() {
_ = os.RemoveAll(dir)
})
return dir
}
func writeTestSegment(t *testing.T, dir, name string, content []byte) {
t.Helper()
if err := os.WriteFile(filepath.Join(dir, name), content, 0o644); err != nil {
t.Fatalf("failed to write test segment %s: %v", name, err)
}
}
func createTestConfig(walDir, serverURL string) *config.Config {
isDeleteEnabled := true
return &config.Config{
DatabasusHost: serverURL,
DbID: "test-db-id",
Token: "test-token",
PgWalDir: walDir,
IsDeleteWalAfterUpload: &isDeleteEnabled,
}
}
func decompressZstd(t *testing.T, data []byte) []byte {
t.Helper()
decoder, err := zstd.NewReader(nil)
require.NoError(t, err)
defer decoder.Close()
decoded, err := decoder.DecodeAll(data, nil)
require.NoError(t, err)
return decoded
}

View File

@@ -1,45 +1,119 @@
package logger
import (
"fmt"
"io"
"log/slog"
"os"
"sync"
"time"
)
const (
logFileName = "databasus.log"
oldLogFileName = "databasus.log.old"
maxLogFileSize = 5 * 1024 * 1024 // 5MB
)
type rotatingWriter struct {
mu sync.Mutex
file *os.File
currentSize int64
maxSize int64
logPath string
oldLogPath string
}
func (w *rotatingWriter) Write(p []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
if w.currentSize+int64(len(p)) > w.maxSize {
if err := w.rotate(); err != nil {
return 0, fmt.Errorf("failed to rotate log file: %w", err)
}
}
n, err := w.file.Write(p)
w.currentSize += int64(n)
return n, err
}
func (w *rotatingWriter) rotate() error {
if err := w.file.Close(); err != nil {
return fmt.Errorf("failed to close %s: %w", w.logPath, err)
}
if err := os.Remove(w.oldLogPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove %s: %w", w.oldLogPath, err)
}
if err := os.Rename(w.logPath, w.oldLogPath); err != nil {
return fmt.Errorf("failed to rename %s to %s: %w", w.logPath, w.oldLogPath, err)
}
f, err := os.OpenFile(w.logPath, os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return fmt.Errorf("failed to create new %s: %w", w.logPath, err)
}
w.file = f
w.currentSize = 0
return nil
}
var (
loggerInstance *slog.Logger
once sync.Once
)
func Init(isDebug bool) {
level := slog.LevelInfo
if isDebug {
level = slog.LevelDebug
}
once.Do(func() {
loggerInstance = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: level,
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
if a.Key == slog.TimeKey {
a.Value = slog.StringValue(time.Now().Format("2006/01/02 15:04:05"))
}
if a.Key == slog.LevelKey {
return slog.Attr{}
}
return a
},
}))
})
}
// GetLogger returns a singleton slog.Logger that logs to the console
func GetLogger() *slog.Logger {
if loggerInstance == nil {
Init(false)
}
once.Do(func() {
initialize()
})
return loggerInstance
}
func initialize() {
writer := buildWriter()
loggerInstance = slog.New(slog.NewTextHandler(writer, &slog.HandlerOptions{
Level: slog.LevelInfo,
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
if a.Key == slog.TimeKey {
a.Value = slog.StringValue(time.Now().Format("2006/01/02 15:04:05"))
}
if a.Key == slog.LevelKey {
return slog.Attr{}
}
return a
},
}))
}
func buildWriter() io.Writer {
f, err := os.OpenFile(logFileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
fmt.Fprintf(os.Stderr, "Warning: failed to open %s for logging: %v\n", logFileName, err)
return os.Stdout
}
var currentSize int64
if info, err := f.Stat(); err == nil {
currentSize = info.Size()
}
rw := &rotatingWriter{
file: f,
currentSize: currentSize,
maxSize: maxLogFileSize,
logPath: logFileName,
oldLogPath: oldLogFileName,
}
return io.MultiWriter(os.Stdout, rw)
}

View File

@@ -0,0 +1,128 @@
package logger
import (
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_Write_DataWrittenToFile(t *testing.T) {
rw, logPath, _ := setupRotatingWriter(t, 1024)
data := []byte("hello world\n")
n, err := rw.Write(data)
require.NoError(t, err)
assert.Equal(t, len(data), n)
assert.Equal(t, int64(len(data)), rw.currentSize)
content, err := os.ReadFile(logPath)
require.NoError(t, err)
assert.Equal(t, string(data), string(content))
}
func Test_Write_WhenLimitExceeded_FileRotated(t *testing.T) {
rw, logPath, oldLogPath := setupRotatingWriter(t, 100)
firstData := []byte(strings.Repeat("A", 80))
_, err := rw.Write(firstData)
require.NoError(t, err)
secondData := []byte(strings.Repeat("B", 30))
_, err = rw.Write(secondData)
require.NoError(t, err)
oldContent, err := os.ReadFile(oldLogPath)
require.NoError(t, err)
assert.Equal(t, string(firstData), string(oldContent))
newContent, err := os.ReadFile(logPath)
require.NoError(t, err)
assert.Equal(t, string(secondData), string(newContent))
assert.Equal(t, int64(len(secondData)), rw.currentSize)
}
func Test_Write_WhenOldFileExists_OldFileReplaced(t *testing.T) {
rw, _, oldLogPath := setupRotatingWriter(t, 100)
require.NoError(t, os.WriteFile(oldLogPath, []byte("stale data"), 0o644))
_, err := rw.Write([]byte(strings.Repeat("A", 80)))
require.NoError(t, err)
_, err = rw.Write([]byte(strings.Repeat("B", 30)))
require.NoError(t, err)
oldContent, err := os.ReadFile(oldLogPath)
require.NoError(t, err)
assert.Equal(t, strings.Repeat("A", 80), string(oldContent))
}
func Test_Write_MultipleSmallWrites_CurrentSizeAccumulated(t *testing.T) {
rw, _, _ := setupRotatingWriter(t, 1024)
var totalWritten int64
for i := 0; i < 10; i++ {
data := []byte("line\n")
n, err := rw.Write(data)
require.NoError(t, err)
totalWritten += int64(n)
}
assert.Equal(t, totalWritten, rw.currentSize)
assert.Equal(t, int64(50), rw.currentSize)
}
func Test_Write_ExactlyAtBoundary_NoRotationUntilNextByte(t *testing.T) {
rw, logPath, oldLogPath := setupRotatingWriter(t, 100)
exactData := []byte(strings.Repeat("X", 100))
_, err := rw.Write(exactData)
require.NoError(t, err)
_, err = os.Stat(oldLogPath)
assert.True(t, os.IsNotExist(err), ".old file should not exist yet")
content, err := os.ReadFile(logPath)
require.NoError(t, err)
assert.Equal(t, string(exactData), string(content))
_, err = rw.Write([]byte("Z"))
require.NoError(t, err)
_, err = os.Stat(oldLogPath)
assert.NoError(t, err, ".old file should exist after exceeding limit")
assert.Equal(t, int64(1), rw.currentSize)
}
func setupRotatingWriter(t *testing.T, maxSize int64) (*rotatingWriter, string, string) {
t.Helper()
dir := t.TempDir()
logPath := filepath.Join(dir, "test.log")
oldLogPath := filepath.Join(dir, "test.log.old")
f, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY, 0o644)
require.NoError(t, err)
rw := &rotatingWriter{
file: f,
currentSize: 0,
maxSize: maxSize,
logPath: logPath,
oldLogPath: oldLogPath,
}
t.Cleanup(func() {
rw.file.Close()
})
return rw, logPath, oldLogPath
}

View File

@@ -59,6 +59,10 @@ func (c *BackupCleaner) Run(ctx context.Context) {
if err := c.cleanExceededBackups(); err != nil {
c.logger.Error("Failed to clean exceeded backups", "error", err)
}
if err := c.cleanStaleUploadedBasebackups(); err != nil {
c.logger.Error("Failed to clean stale uploaded basebackups", "error", err)
}
}
}
})
@@ -100,6 +104,67 @@ func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemo
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
}
func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
staleBackups, err := c.backupRepository.FindStaleUploadedBasebackups(
time.Now().UTC().Add(-10 * time.Minute),
)
if err != nil {
return fmt.Errorf("failed to find stale uploaded basebackups: %w", err)
}
for _, backup := range staleBackups {
staleStorage, storageErr := c.storageService.GetStorageByID(backup.StorageID)
if storageErr != nil {
c.logger.Error(
"Failed to get storage for stale basebackup cleanup",
"backupId", backup.ID,
"storageId", backup.StorageID,
"error", storageErr,
)
} else {
if err := staleStorage.DeleteFile(c.fieldEncryptor, backup.FileName); err != nil {
c.logger.Error(
"Failed to delete stale basebackup file",
"backupId", backup.ID,
"fileName", backup.FileName,
"error", err,
)
}
metadataFileName := backup.FileName + ".metadata"
if err := staleStorage.DeleteFile(c.fieldEncryptor, metadataFileName); err != nil {
c.logger.Error(
"Failed to delete stale basebackup metadata file",
"backupId", backup.ID,
"fileName", metadataFileName,
"error", err,
)
}
}
failMsg := "basebackup finalization timed out after 10 minutes"
backup.Status = backups_core.BackupStatusFailed
backup.FailMessage = &failMsg
if err := c.backupRepository.Save(backup); err != nil {
c.logger.Error(
"Failed to mark stale uploaded basebackup as failed",
"backupId", backup.ID,
"error", err,
)
continue
}
c.logger.Info(
"Marked stale uploaded basebackup as failed and cleaned storage",
"backupId", backup.ID,
"databaseId", backup.DatabaseID,
)
}
return nil
}
func (c *BackupCleaner) cleanByRetentionPolicy() error {
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {

View File

@@ -1004,6 +1004,191 @@ func (m *mockBackupRemoveListener) OnBeforeBackupRemove(backup *backups_core.Bac
return nil
}
func Test_CleanStaleUploadedBasebackups_MarksAsFailed(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
staleTime := time.Now().UTC().Add(-15 * time.Minute)
walBackupType := backups_core.PgWalBackupTypeFullBackup
staleBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
PgWalBackupType: &walBackupType,
UploadCompletedAt: &staleTime,
CreatedAt: staleTime,
}
err := backupRepository.Save(staleBackup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups()
assert.NoError(t, err)
updated, err := backupRepository.FindByID(staleBackup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusFailed, updated.Status)
assert.NotNil(t, updated.FailMessage)
assert.Contains(t, *updated.FailMessage, "finalization timed out")
}
func Test_CleanStaleUploadedBasebackups_SkipsRecentUploads(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
recentTime := time.Now().UTC().Add(-2 * time.Minute)
walBackupType := backups_core.PgWalBackupTypeFullBackup
recentBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
PgWalBackupType: &walBackupType,
UploadCompletedAt: &recentTime,
CreatedAt: recentTime,
}
err := backupRepository.Save(recentBackup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups()
assert.NoError(t, err)
updated, err := backupRepository.FindByID(recentBackup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusInProgress, updated.Status)
}
func Test_CleanStaleUploadedBasebackups_SkipsActiveStreaming(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
walBackupType := backups_core.PgWalBackupTypeFullBackup
activeBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
PgWalBackupType: &walBackupType,
CreatedAt: time.Now().UTC().Add(-30 * time.Minute),
}
err := backupRepository.Save(activeBackup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups()
assert.NoError(t, err)
updated, err := backupRepository.FindByID(activeBackup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusInProgress, updated.Status)
assert.Nil(t, updated.UploadCompletedAt)
}
func Test_CleanStaleUploadedBasebackups_CleansStorageFiles(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
staleTime := time.Now().UTC().Add(-15 * time.Minute)
walBackupType := backups_core.PgWalBackupTypeFullBackup
staleBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
PgWalBackupType: &walBackupType,
UploadCompletedAt: &staleTime,
BackupSizeMb: 500,
FileName: "stale-basebackup-test-file",
CreatedAt: staleTime,
}
err := backupRepository.Save(staleBackup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.cleanStaleUploadedBasebackups()
assert.NoError(t, err)
updated, err := backupRepository.FindByID(staleBackup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusFailed, updated.Status)
assert.NotNil(t, updated.FailMessage)
assert.Contains(t, *updated.FailMessage, "finalization timed out")
}
func createTestInterval() *intervals.Interval {
timeOfDay := "04:00"
interval := &intervals.Interval{

View File

@@ -3,12 +3,10 @@ package backups_controllers
import (
"io"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_dto "databasus-backend/internal/features/backups/backups/dto"
backups_services "databasus-backend/internal/features/backups/backups/services"
"databasus-backend/internal/features/databases"
@@ -25,8 +23,11 @@ func (c *PostgreWalBackupController) RegisterRoutes(router *gin.RouterGroup) {
walRoutes := router.Group("/backups/postgres/wal")
walRoutes.GET("/next-full-backup-time", c.GetNextFullBackupTime)
walRoutes.GET("/is-wal-chain-valid-since-last-full-backup", c.IsWalChainValidSinceLastBackup)
walRoutes.POST("/error", c.ReportError)
walRoutes.POST("/upload", c.Upload)
walRoutes.POST("/upload/wal", c.UploadWalSegment)
walRoutes.POST("/upload/full-start", c.StartFullBackupUpload)
walRoutes.POST("/upload/full-complete", c.CompleteFullBackupUpload)
walRoutes.GET("/restore/plan", c.GetRestorePlan)
walRoutes.GET("/restore/download", c.DownloadBackupFile)
}
@@ -90,91 +91,66 @@ func (c *PostgreWalBackupController) ReportError(ctx *gin.Context) {
ctx.Status(http.StatusOK)
}
// Upload
// @Summary Stream upload a basebackup or WAL segment
// @Description Accepts a zstd-compressed binary stream and stores it in the database's configured storage.
// The server generates the storage filename; agents do not control the destination path.
// For WAL segment uploads the server validates the WAL chain and returns 409 if a gap is detected
// or 400 if no full backup exists yet (agent should trigger a full basebackup in both cases).
// IsWalChainValidSinceLastBackup
// @Summary Check WAL chain validity since last full backup
// @Description Checks whether the WAL chain is continuous since the last completed full backup.
// Returns isValid=true if the chain is intact, or isValid=false with error details if not.
// @Tags backups-wal
// @Accept application/octet-stream
// @Produce json
// @Security AgentToken
// @Param X-Upload-Type header string true "Upload type" Enums(basebackup, wal)
// @Param X-Wal-Segment-Name header string false "24-hex WAL segment identifier (required for wal uploads, e.g. 0000000100000001000000AB)"
// @Param X-Wal-Segment-Size header int false "WAL segment size in bytes reported by the PostgreSQL instance (default: 16777216)"
// @Param fullBackupWalStartSegment query string false "First WAL segment needed to make the basebackup consistent (required for basebackup uploads)"
// @Param fullBackupWalStopSegment query string false "Last WAL segment included in the basebackup (required for basebackup uploads)"
// @Success 204
// @Failure 400 {object} backups_dto.UploadGapResponse "No full backup exists (error: no_full_backup)"
// @Success 200 {object} backups_dto.IsWalChainValidResponse
// @Failure 401 {object} map[string]string
// @Failure 409 {object} backups_dto.UploadGapResponse "WAL chain gap detected (error: gap_detected)"
// @Failure 500 {object} map[string]string
// @Router /backups/postgres/wal/upload [post]
func (c *PostgreWalBackupController) Upload(ctx *gin.Context) {
// @Router /backups/postgres/wal/is-wal-chain-valid-since-last-full-backup [get]
func (c *PostgreWalBackupController) IsWalChainValidSinceLastBackup(ctx *gin.Context) {
database, err := c.getDatabase(ctx)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
return
}
uploadType := backups_core.PgWalUploadType(ctx.GetHeader("X-Upload-Type"))
if uploadType != backups_core.PgWalUploadTypeBasebackup &&
uploadType != backups_core.PgWalUploadTypeWal {
response, err := c.walService.IsWalChainValid(database)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, response)
}
// UploadWalSegment
// @Summary Stream upload a WAL segment
// @Description Accepts a zstd-compressed WAL segment binary stream and stores it in the database's configured storage.
// WAL segments are accepted unconditionally.
// @Tags backups-wal
// @Accept application/octet-stream
// @Security AgentToken
// @Param X-Wal-Segment-Name header string true "24-hex WAL segment identifier (e.g. 0000000100000001000000AB)"
// @Success 204
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /backups/postgres/wal/upload/wal [post]
func (c *PostgreWalBackupController) UploadWalSegment(ctx *gin.Context) {
database, err := c.getDatabase(ctx)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
return
}
walSegmentName := ctx.GetHeader("X-Wal-Segment-Name")
if walSegmentName == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "X-Upload-Type must be 'basebackup' or 'wal'"},
gin.H{"error": "X-Wal-Segment-Name is required for wal uploads"},
)
return
}
walSegmentName := ""
if uploadType == backups_core.PgWalUploadTypeWal {
walSegmentName = ctx.GetHeader("X-Wal-Segment-Name")
if walSegmentName == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "X-Wal-Segment-Name is required for wal uploads"},
)
return
}
}
if uploadType == backups_core.PgWalUploadTypeBasebackup {
if ctx.Query("fullBackupWalStartSegment") == "" ||
ctx.Query("fullBackupWalStopSegment") == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{
"error": "fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads",
},
)
return
}
}
walSegmentSizeBytes := int64(0)
if raw := ctx.GetHeader("X-Wal-Segment-Size"); raw != "" {
parsed, parseErr := strconv.ParseInt(raw, 10, 64)
if parseErr != nil || parsed <= 0 {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "X-Wal-Segment-Size must be a positive integer"},
)
return
}
walSegmentSizeBytes = parsed
}
gapResp, uploadErr := c.walService.UploadWal(
uploadErr := c.walService.UploadWalSegment(
ctx.Request.Context(),
database,
uploadType,
walSegmentName,
ctx.Query("fullBackupWalStartSegment"),
ctx.Query("fullBackupWalStopSegment"),
walSegmentSizeBytes,
ctx.Request.Body,
)
@@ -183,17 +159,81 @@ func (c *PostgreWalBackupController) Upload(ctx *gin.Context) {
return
}
if gapResp != nil {
if gapResp.Error == "no_full_backup" {
ctx.JSON(http.StatusBadRequest, gapResp)
return
}
ctx.Status(http.StatusNoContent)
}
ctx.JSON(http.StatusConflict, gapResp)
// StartFullBackupUpload
// @Summary Stream upload a full basebackup (Phase 1)
// @Description Accepts a zstd-compressed basebackup binary stream and stores it in the database's configured storage.
// Returns a backupId that must be completed via /upload/full-complete with WAL segment names.
// @Tags backups-wal
// @Accept application/octet-stream
// @Produce json
// @Security AgentToken
// @Success 200 {object} backups_dto.UploadBasebackupResponse
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /backups/postgres/wal/upload/full-start [post]
func (c *PostgreWalBackupController) StartFullBackupUpload(ctx *gin.Context) {
database, err := c.getDatabase(ctx)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
return
}
ctx.Status(http.StatusNoContent)
backupID, uploadErr := c.walService.UploadBasebackup(
ctx.Request.Context(),
database,
ctx.Request.Body,
)
if uploadErr != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": uploadErr.Error()})
return
}
ctx.JSON(http.StatusOK, backups_dto.UploadBasebackupResponse{
BackupID: backupID,
})
}
// CompleteFullBackupUpload
// @Summary Complete a previously uploaded basebackup (Phase 2)
// @Description Sets WAL segment names and marks the basebackup as completed, or marks it as failed if an error is provided.
// @Tags backups-wal
// @Accept json
// @Security AgentToken
// @Param request body backups_dto.FinalizeBasebackupRequest true "Completion details"
// @Success 200
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /backups/postgres/wal/upload/full-complete [post]
func (c *PostgreWalBackupController) CompleteFullBackupUpload(ctx *gin.Context) {
database, err := c.getDatabase(ctx)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
return
}
var request backups_dto.FinalizeBasebackupRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := c.walService.FinalizeBasebackup(
database,
request.BackupID,
request.StartSegment,
request.StopSegment,
request.Error,
); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.Status(http.StatusOK)
}
// GetRestorePlan

View File

@@ -38,7 +38,7 @@ func Test_WalUpload_InProgressStatusSetBeforeStream(t *testing.T) {
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
pr, pw := io.Pipe()
req := newWalUploadRequest(pr, agentToken, "wal", "000000010000000100000011", "", "")
req := newWalSegmentUploadRequest(pr, agentToken, "000000010000000100000011")
w := httptest.NewRecorder()
done := make(chan struct{})
@@ -67,7 +67,7 @@ func Test_WalUpload_CompletedStatusAfterSuccessfulStream(t *testing.T) {
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
body := bytes.NewReader([]byte("wal segment content"))
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000011", "", "")
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000011")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -99,7 +99,7 @@ func Test_WalUpload_FailedStatusWithErrorOnStreamError(t *testing.T) {
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
pr, pw := io.Pipe()
req := newWalUploadRequest(pr, agentToken, "wal", "000000010000000100000011", "", "")
req := newWalSegmentUploadRequest(pr, agentToken, "000000010000000100000011")
w := httptest.NewRecorder()
done := make(chan struct{})
@@ -129,59 +129,171 @@ func Test_WalUpload_FailedStatusWithErrorOnStreamError(t *testing.T) {
assert.NotNil(t, walBackup.FailMessage)
}
func Test_WalUpload_Basebackup_MissingWalSegments_Returns400(t *testing.T) {
func Test_WalUpload_Basebackup_StreamingUpload_Returns200WithBackupId(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
body := bytes.NewReader([]byte("basebackup content"))
req := newWalUploadRequest(body, agentToken, backups_core.PgWalUploadTypeBasebackup, "", "", "")
req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/full-start", body)
req.Header.Set("Authorization", agentToken)
req.Header.Set("Content-Type", "application/octet-stream")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
var response backups_dto.UploadBasebackupResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
assert.NotEqual(t, uuid.Nil, response.BackupID)
backup, err := backups_core.GetBackupRepository().FindByID(response.BackupID)
require.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusInProgress, backup.Status)
assert.NotNil(t, backup.UploadCompletedAt)
}
func Test_FinalizeBasebackup_ValidSegments_MarksCompleted(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
backupID := uploadBasebackupPhase1(t, router, agentToken)
completeFullBackupUpload(t, router, agentToken, backupID,
"000000010000000100000001", "000000010000000100000010", nil)
backup, err := backups_core.GetBackupRepository().FindByID(backupID)
require.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
require.NotNil(t, backup.PgFullBackupWalStartSegmentName)
assert.Equal(t, "000000010000000100000001", *backup.PgFullBackupWalStartSegmentName)
require.NotNil(t, backup.PgFullBackupWalStopSegmentName)
assert.Equal(t, "000000010000000100000010", *backup.PgFullBackupWalStopSegmentName)
}
func Test_FinalizeBasebackup_WithError_MarksFailed(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
backupID := uploadBasebackupPhase1(t, router, agentToken)
errMsg := "pg_basebackup stderr parse failed"
completeFullBackupUpload(t, router, agentToken, backupID, "", "", &errMsg)
backup, err := backups_core.GetBackupRepository().FindByID(backupID)
require.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusFailed, backup.Status)
require.NotNil(t, backup.FailMessage)
assert.Equal(t, errMsg, *backup.FailMessage)
}
func Test_FinalizeBasebackup_InvalidBackupId_Returns400(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
nonExistentID := uuid.New()
body, _ := json.Marshal(backups_dto.FinalizeBasebackupRequest{
BackupID: nonExistentID,
StartSegment: "000000010000000100000001",
StopSegment: "000000010000000100000010",
})
req, _ := http.NewRequest(
http.MethodPost,
"/api/v1/backups/postgres/wal/upload/full-complete",
bytes.NewReader(body),
)
req.Header.Set("Authorization", agentToken)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
}
func Test_WalUpload_WalSegment_NoFullBackup_Returns400(t *testing.T) {
func Test_FinalizeBasebackup_AlreadyCompleted_Returns400(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
backupID := uploadBasebackupPhase1(t, router, agentToken)
completeFullBackupUpload(t, router, agentToken, backupID,
"000000010000000100000001", "000000010000000100000010", nil)
// Second finalize should fail.
body, _ := json.Marshal(backups_dto.FinalizeBasebackupRequest{
BackupID: backupID,
StartSegment: "000000010000000100000001",
StopSegment: "000000010000000100000010",
})
req, _ := http.NewRequest(
http.MethodPost,
"/api/v1/backups/postgres/wal/upload/full-complete",
bytes.NewReader(body),
)
req.Header.Set("Authorization", agentToken)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
}
func Test_FinalizeBasebackup_InvalidToken_Returns401(t *testing.T) {
router, db, storage, _, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
body, _ := json.Marshal(backups_dto.FinalizeBasebackupRequest{
BackupID: uuid.New(),
StartSegment: "000000010000000100000001",
StopSegment: "000000010000000100000010",
})
req, _ := http.NewRequest(
http.MethodPost,
"/api/v1/backups/postgres/wal/upload/full-complete",
bytes.NewReader(body),
)
req.Header.Set("Authorization", "invalid-token")
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}
func Test_WalUpload_WalSegment_WithoutFullBackup_Returns204(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
// No full backup inserted — chain anchor is missing.
body := bytes.NewReader([]byte("wal content"))
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000001", "", "")
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000001")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
var resp backups_dto.UploadGapResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "no_full_backup", resp.Error)
assert.Equal(t, http.StatusNoContent, w.Code)
}
func Test_WalUpload_WalSegment_GapDetected_Returns409WithExpectedAndReceived(t *testing.T) {
func Test_WalUpload_WalSegment_WithGap_Returns204(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
// Full backup stops at ...0010; upload one WAL segment at ...0011.
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
// Send ...0013 — should be rejected because ...0012 is missing.
// Skip ...0012, upload ...0013 — should succeed (no chain validation on upload).
body := bytes.NewReader([]byte("wal content"))
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000013", "", "")
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000013")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusConflict, w.Code)
var resp backups_dto.UploadGapResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "gap_detected", resp.Error)
assert.Equal(t, "000000010000000100000012", resp.ExpectedSegmentName)
assert.Equal(t, "000000010000000100000013", resp.ReceivedSegmentName)
assert.Equal(t, http.StatusNoContent, w.Code)
}
func Test_WalUpload_WalSegment_DuplicateSegment_Returns200Idempotent(t *testing.T) {
@@ -192,14 +304,14 @@ func Test_WalUpload_WalSegment_DuplicateSegment_Returns200Idempotent(t *testing.
// Upload ...0011 once.
body1 := bytes.NewReader([]byte("wal content"))
req1 := newWalUploadRequest(body1, agentToken, "wal", "000000010000000100000011", "", "")
req1 := newWalSegmentUploadRequest(body1, agentToken, "000000010000000100000011")
w1 := httptest.NewRecorder()
router.ServeHTTP(w1, req1)
require.Equal(t, http.StatusNoContent, w1.Code)
// Upload the same segment again — must return 204 (idempotent).
body2 := bytes.NewReader([]byte("wal content"))
req2 := newWalUploadRequest(body2, agentToken, "wal", "000000010000000100000011", "", "")
req2 := newWalSegmentUploadRequest(body2, agentToken, "000000010000000100000011")
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
@@ -228,7 +340,7 @@ func Test_WalUpload_WalSegment_ValidNextSegment_Returns200AndCreatesRecord(t *te
// First WAL segment after the full backup stop segment.
body := bytes.NewReader([]byte("wal segment data"))
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000011", "", "")
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000011")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -255,6 +367,108 @@ func Test_WalUpload_WalSegment_ValidNextSegment_Returns200AndCreatesRecord(t *te
assert.Equal(t, "000000010000000100000011", *walBackup.PgWalSegmentName)
}
func Test_IsWalChainValid_NoFullBackup_ReturnsFalse(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
var response backups_dto.IsWalChainValidResponse
test_utils.MakeGetRequestAndUnmarshal(
t, router,
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
agentToken,
http.StatusOK,
&response,
)
assert.False(t, response.IsValid)
assert.Equal(t, "no_full_backup", response.Error)
}
func Test_IsWalChainValid_FullBackupOnly_ReturnsTrue(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
var response backups_dto.IsWalChainValidResponse
test_utils.MakeGetRequestAndUnmarshal(
t, router,
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
agentToken,
http.StatusOK,
&response,
)
assert.True(t, response.IsValid)
assert.Empty(t, response.Error)
}
func Test_IsWalChainValid_ContinuousChain_ReturnsTrue(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
uploadWalSegment(t, router, agentToken, "000000010000000100000012")
uploadWalSegment(t, router, agentToken, "000000010000000100000013")
var response backups_dto.IsWalChainValidResponse
test_utils.MakeGetRequestAndUnmarshal(
t, router,
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
agentToken,
http.StatusOK,
&response,
)
assert.True(t, response.IsValid)
}
func Test_IsWalChainValid_BrokenChain_ReturnsFalse(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
uploadWalSegment(t, router, agentToken, "000000010000000100000012")
uploadWalSegment(t, router, agentToken, "000000010000000100000013")
// Delete the middle segment to create a gap.
middleSeg, err := backups_core.GetBackupRepository().FindWalSegmentByName(
db.ID, "000000010000000100000012",
)
require.NoError(t, err)
require.NotNil(t, middleSeg)
require.NoError(t, backups_core.GetBackupRepository().DeleteByID(middleSeg.ID))
var response backups_dto.IsWalChainValidResponse
test_utils.MakeGetRequestAndUnmarshal(
t, router,
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
agentToken,
http.StatusOK,
&response,
)
assert.False(t, response.IsValid)
assert.Equal(t, "wal_chain_broken", response.Error)
assert.Equal(t, "000000010000000100000011", response.LastContiguousSegment)
}
func Test_IsWalChainValid_InvalidToken_Returns401(t *testing.T) {
router, db, storage, _, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
resp := test_utils.MakeGetRequest(
t, router,
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
"invalid-token",
http.StatusUnauthorized,
)
assert.Contains(t, string(resp.Body), "invalid agent token")
}
func Test_ReportError_ValidTokenAndError_CreatesFailedBackupRecord(t *testing.T) {
router, db, storage, agentToken, _ := createWalTestSetup(t)
defer removeWalTestSetup(db, storage)
@@ -457,29 +671,14 @@ func Test_GetNextFullBackupTime_WalSegmentAfterFullBackup_DoesNotImpactTime(t *t
setHourlyInterval(t, router, db.ID, ownerToken)
// Upload basebackup via API.
bbBody := bytes.NewReader([]byte("basebackup content"))
bbReq := newWalUploadRequest(
bbBody, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
"000000010000000100000001", "000000010000000100000010",
)
bbW := httptest.NewRecorder()
router.ServeHTTP(bbW, bbReq)
require.Equal(t, http.StatusNoContent, bbW.Code)
uploadBasebackup(t, router, agentToken,
"000000010000000100000001", "000000010000000100000010")
// Shift the full backup's CreatedAt to 2 hours ago.
twoHoursAgo := time.Now().UTC().Add(-2 * time.Hour)
updateLastFullBackupTime(t, db.ID, twoHoursAgo)
// Upload WAL segment via API.
walBody := bytes.NewReader([]byte("wal segment content"))
walReq := newWalUploadRequest(
walBody, agentToken, backups_core.PgWalUploadTypeWal,
"000000010000000100000011", "", "",
)
walW := httptest.NewRecorder()
router.ServeHTTP(walW, walReq)
require.Equal(t, http.StatusNoContent, walW.Code)
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
var response backups_dto.GetNextFullBackupTimeResponse
test_utils.MakeGetRequestAndUnmarshal(
@@ -508,15 +707,8 @@ func Test_GetNextFullBackupTime_FailedBasebackup_DoesNotImpactTime(t *testing.T)
setHourlyInterval(t, router, db.ID, ownerToken)
// Upload a successful basebackup via API.
bbBody := bytes.NewReader([]byte("basebackup content"))
bbReq := newWalUploadRequest(
bbBody, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
"000000010000000100000001", "000000010000000100000010",
)
bbW := httptest.NewRecorder()
router.ServeHTTP(bbW, bbReq)
require.Equal(t, http.StatusNoContent, bbW.Code)
uploadBasebackup(t, router, agentToken,
"000000010000000100000001", "000000010000000100000010")
// Shift the full backup's CreatedAt to 2 hours ago.
twoHoursAgo := time.Now().UTC().Add(-2 * time.Hour)
@@ -563,15 +755,8 @@ func Test_GetNextFullBackupTime_NewCompletedFullBackup_ImpactsTime(t *testing.T)
setHourlyInterval(t, router, db.ID, ownerToken)
// Upload first basebackup via API.
bb1 := bytes.NewReader([]byte("first basebackup"))
bb1Req := newWalUploadRequest(
bb1, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
"000000010000000100000001", "000000010000000100000010",
)
bb1W := httptest.NewRecorder()
router.ServeHTTP(bb1W, bb1Req)
require.Equal(t, http.StatusNoContent, bb1W.Code)
uploadBasebackup(t, router, agentToken,
"000000010000000100000001", "000000010000000100000010")
// Shift the first backup's CreatedAt to 3 hours ago.
threeHoursAgo := time.Now().UTC().Add(-3 * time.Hour)
@@ -595,15 +780,8 @@ func Test_GetNextFullBackupTime_NewCompletedFullBackup_ImpactsTime(t *testing.T)
"first next time should be in the past (old backup)",
)
// Upload second basebackup via API (created now).
bb2 := bytes.NewReader([]byte("second basebackup"))
bb2Req := newWalUploadRequest(
bb2, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
"000000010000000100000011", "000000010000000100000020",
)
bb2W := httptest.NewRecorder()
router.ServeHTTP(bb2W, bb2Req)
require.Equal(t, http.StatusNoContent, bb2W.Code)
uploadBasebackup(t, router, agentToken,
"000000010000000100000011", "000000010000000100000020")
var secondResponse backups_dto.GetNextFullBackupTimeResponse
test_utils.MakeGetRequestAndUnmarshal(
@@ -841,15 +1019,18 @@ func Test_DownloadRestoreFile_UploadThenDownload_ContentMatches(t *testing.T) {
uploadContent := "test-basebackup-content-for-download"
body := bytes.NewReader([]byte(uploadContent))
req := newWalUploadRequest(
body, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
"000000010000000100000001", "000000010000000100000010",
)
req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/full-start", body)
req.Header.Set("Authorization", agentToken)
req.Header.Set("Content-Type", "application/octet-stream")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNoContent, w.Code)
require.Equal(t, http.StatusOK, w.Code)
WaitForBackupCompletion(t, db.ID, 0, 5*time.Second)
var uploadResp backups_dto.UploadBasebackupResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &uploadResp))
completeFullBackupUpload(t, router, agentToken, uploadResp.BackupID,
"000000010000000100000001", "000000010000000100000010", nil)
var planResp backups_dto.GetRestorePlanResponse
test_utils.MakeGetRequestAndUnmarshal(
@@ -883,7 +1064,7 @@ func Test_DownloadRestoreFile_WalSegment_UploadThenDownload_ContentMatches(t *te
walContent := "test-wal-segment-content-for-download"
body := bytes.NewReader([]byte(walContent))
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000011", "", "")
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000011")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNoContent, w.Code)
@@ -1088,35 +1269,81 @@ func removeWalTestSetup(db *databases.Database, storage *storages.Storage) {
storages.RemoveTestStorage(storage.ID)
}
func newWalUploadRequest(
func newWalSegmentUploadRequest(
body io.Reader,
agentToken string,
uploadType backups_core.PgWalUploadType,
walSegmentName string,
walStart string,
walStop string,
segmentName string,
) *http.Request {
url := "/api/v1/backups/postgres/wal/upload"
if walStart != "" || walStop != "" {
url += "?fullBackupWalStartSegment=" + walStart + "&fullBackupWalStopSegment=" + walStop
}
req, err := http.NewRequest(http.MethodPost, url, body)
req, err := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/wal", body)
if err != nil {
panic(err)
}
req.Header.Set("Authorization", agentToken)
req.Header.Set("Content-Type", "application/octet-stream")
req.Header.Set("X-Upload-Type", string(uploadType))
if walSegmentName != "" {
req.Header.Set("X-Wal-Segment-Name", walSegmentName)
}
req.Header.Set("X-Wal-Segment-Name", segmentName)
return req
}
func uploadBasebackupPhase1(
t *testing.T,
router *gin.Engine,
agentToken string,
) uuid.UUID {
t.Helper()
body := bytes.NewReader([]byte("test-basebackup-content"))
req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/full-start", body)
req.Header.Set("Authorization", agentToken)
req.Header.Set("Content-Type", "application/octet-stream")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
var response backups_dto.UploadBasebackupResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
require.NotEqual(t, uuid.Nil, response.BackupID)
return response.BackupID
}
func completeFullBackupUpload(
t *testing.T,
router *gin.Engine,
agentToken string,
backupID uuid.UUID,
walStart string,
walStop string,
errMsg *string,
) {
t.Helper()
request := backups_dto.FinalizeBasebackupRequest{
BackupID: backupID,
StartSegment: walStart,
StopSegment: walStop,
Error: errMsg,
}
reqBody, _ := json.Marshal(request)
req, _ := http.NewRequest(
http.MethodPost,
"/api/v1/backups/postgres/wal/upload/full-complete",
bytes.NewReader(reqBody),
)
req.Header.Set("Authorization", agentToken)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func uploadBasebackup(
t *testing.T,
router *gin.Engine,
@@ -1126,15 +1353,8 @@ func uploadBasebackup(
) {
t.Helper()
body := bytes.NewReader([]byte("test-basebackup-content"))
req := newWalUploadRequest(
body, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
walStart, walStop,
)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNoContent, w.Code)
backupID := uploadBasebackupPhase1(t, router, agentToken)
completeFullBackupUpload(t, router, agentToken, backupID, walStart, walStop, nil)
}
func uploadWalSegment(
@@ -1146,9 +1366,12 @@ func uploadWalSegment(
t.Helper()
body := bytes.NewReader([]byte("test-wal-segment-content"))
req := newWalUploadRequest(
body, agentToken, backups_core.PgWalUploadTypeWal, segmentName, "", "",
)
req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/wal", body)
req.Header.Set("Authorization", agentToken)
req.Header.Set("Content-Type", "application/octet-stream")
req.Header.Set("X-Wal-Segment-Name", segmentName)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)

View File

@@ -43,7 +43,8 @@ type Backup struct {
PgVersion *string `json:"pgVersion" gorm:"column:pg_version;type:text"`
PgWalSegmentName *string `json:"pgWalSegmentName" gorm:"column:pg_wal_segment_name;type:text"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
UploadCompletedAt *time.Time `json:"uploadCompletedAt" gorm:"column:upload_completed_at"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
}
func (b *Backup) GenerateFilename(dbName string) {

View File

@@ -349,6 +349,24 @@ func (r *BackupRepository) FindWalSegmentByName(
return &backup, nil
}
func (r *BackupRepository) FindStaleUploadedBasebackups(olderThan time.Time) ([]*Backup, error) {
var backups []*Backup
err := storage.
GetDb().
Where(
"status = ? AND upload_completed_at IS NOT NULL AND upload_completed_at < ?",
BackupStatusInProgress,
olderThan,
).
Find(&backups).Error
if err != nil {
return nil, err
}
return backups, nil
}
func (r *BackupRepository) FindLastWalSegmentAfter(
databaseID uuid.UUID,
afterSegmentName string,

View File

@@ -44,10 +44,10 @@ type ReportErrorRequest struct {
Error string `json:"error" binding:"required"`
}
type UploadGapResponse struct {
Error string `json:"error"`
ExpectedSegmentName string `json:"expectedSegmentName"`
ReceivedSegmentName string `json:"receivedSegmentName"`
type IsWalChainValidResponse struct {
IsValid bool `json:"isValid"`
Error string `json:"error,omitempty"`
LastContiguousSegment string `json:"lastContiguousSegment,omitempty"`
}
type RestorePlanFullBackup struct {
@@ -77,3 +77,14 @@ type GetRestorePlanResponse struct {
TotalSizeBytes int64 `json:"totalSizeBytes"`
LatestAvailableSegment string `json:"latestAvailableSegment"`
}
type UploadBasebackupResponse struct {
BackupID uuid.UUID `json:"backupId"`
}
type FinalizeBasebackupRequest struct {
BackupID uuid.UUID `json:"backupId" binding:"required"`
StartSegment string `json:"startSegment" binding:"required"`
StopSegment string `json:"stopSegment" binding:"required"`
Error *string `json:"error"`
}

View File

@@ -30,75 +30,46 @@ type PostgreWalBackupService struct {
backupService *BackupService
}
// UploadWal accepts a streaming WAL segment or basebackup upload from the agent.
// For WAL segments it validates the WAL chain before accepting. Returns an UploadGapResponse
// (409) when the chain is broken so the agent knows to trigger a full basebackup.
func (s *PostgreWalBackupService) UploadWal(
// UploadWalSegment accepts a streaming WAL segment upload from the agent.
// WAL segments are accepted unconditionally.
func (s *PostgreWalBackupService) UploadWalSegment(
ctx context.Context,
database *databases.Database,
uploadType backups_core.PgWalUploadType,
walSegmentName string,
fullBackupWalStartSegment string,
fullBackupWalStopSegment string,
walSegmentSizeBytes int64,
body io.Reader,
) (*backups_dto.UploadGapResponse, error) {
) error {
if err := s.validateWalBackupType(database); err != nil {
return nil, err
}
if uploadType == backups_core.PgWalUploadTypeBasebackup {
if fullBackupWalStartSegment == "" || fullBackupWalStopSegment == "" {
return nil, fmt.Errorf(
"fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads",
)
}
return err
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
if err != nil {
return nil, fmt.Errorf("failed to get backup config: %w", err)
return fmt.Errorf("failed to get backup config: %w", err)
}
if backupConfig.Storage == nil {
return nil, fmt.Errorf("no storage configured for database %s", database.ID)
return fmt.Errorf("no storage configured for database %s", database.ID)
}
if uploadType == backups_core.PgWalUploadTypeWal {
// Idempotency: check before chain validation so a successful re-upload is
// not misidentified as a gap.
existing, err := s.backupRepository.FindWalSegmentByName(database.ID, walSegmentName)
if err != nil {
return nil, fmt.Errorf("failed to check for duplicate WAL segment: %w", err)
}
if existing != nil {
return nil, nil
}
gapResp, err := s.validateWalChain(database.ID, walSegmentName, walSegmentSizeBytes)
if err != nil {
return nil, err
}
if gapResp != nil {
return gapResp, nil
}
existing, err := s.backupRepository.FindWalSegmentByName(database.ID, walSegmentName)
if err != nil {
return fmt.Errorf("failed to check for duplicate WAL segment: %w", err)
}
backup := s.createBackupRecord(
if existing != nil {
return nil
}
backup := s.createWalSegmentRecord(
database.ID,
backupConfig.Storage.ID,
uploadType,
database.Name,
walSegmentName,
fullBackupWalStartSegment,
fullBackupWalStopSegment,
backupConfig.Encryption,
)
if err := s.backupRepository.Save(backup); err != nil {
return nil, fmt.Errorf("failed to create backup record: %w", err)
return fmt.Errorf("failed to create backup record: %w", err)
}
sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, body)
@@ -106,12 +77,106 @@ func (s *PostgreWalBackupService) UploadWal(
errMsg := streamErr.Error()
s.markFailed(backup, errMsg)
return nil, fmt.Errorf("upload failed: %w", streamErr)
return fmt.Errorf("upload failed: %w", streamErr)
}
s.markCompleted(backup, sizeBytes)
return nil, nil
return nil
}
// UploadBasebackup accepts a streaming basebackup upload from the agent (Phase 1).
// The backup stays IN_PROGRESS with UploadCompletedAt set after streaming finishes.
// The agent must call FinalizeBasebackup (Phase 2) with WAL segment names to complete.
func (s *PostgreWalBackupService) UploadBasebackup(
ctx context.Context,
database *databases.Database,
body io.Reader,
) (uuid.UUID, error) {
if err := s.validateWalBackupType(database); err != nil {
return uuid.Nil, err
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
if err != nil {
return uuid.Nil, fmt.Errorf("failed to get backup config: %w", err)
}
if backupConfig.Storage == nil {
return uuid.Nil, fmt.Errorf("no storage configured for database %s", database.ID)
}
backup := s.createBasebackupRecord(
database.ID,
backupConfig.Storage.ID,
database.Name,
backupConfig.Encryption,
)
if err := s.backupRepository.Save(backup); err != nil {
return uuid.Nil, fmt.Errorf("failed to create backup record: %w", err)
}
sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, body)
if streamErr != nil {
errMsg := streamErr.Error()
s.markFailed(backup, errMsg)
return uuid.Nil, fmt.Errorf("upload failed: %w", streamErr)
}
now := time.Now().UTC()
backup.UploadCompletedAt = &now
backup.BackupSizeMb = float64(sizeBytes) / (1024 * 1024)
if err := s.backupRepository.Save(backup); err != nil {
return uuid.Nil, fmt.Errorf("failed to update backup after upload: %w", err)
}
return backup.ID, nil
}
// FinalizeBasebackup completes a previously uploaded basebackup (Phase 2).
// Sets WAL segment names and marks the backup as COMPLETED, or marks it FAILED if errorMsg is provided.
func (s *PostgreWalBackupService) FinalizeBasebackup(
database *databases.Database,
backupID uuid.UUID,
startSegment string,
stopSegment string,
errorMsg *string,
) error {
if err := s.validateWalBackupType(database); err != nil {
return err
}
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return fmt.Errorf("backup not found: %w", err)
}
if backup.DatabaseID != database.ID {
return fmt.Errorf("backup does not belong to this database")
}
if backup.Status != backups_core.BackupStatusInProgress || backup.UploadCompletedAt == nil {
return fmt.Errorf("backup is not awaiting finalization")
}
if errorMsg != nil {
s.markFailed(backup, *errorMsg)
return nil
}
backup.PgFullBackupWalStartSegmentName = &startSegment
backup.PgFullBackupWalStopSegmentName = &stopSegment
backup.Status = backups_core.BackupStatusCompleted
if err := s.backupRepository.Save(backup); err != nil {
return fmt.Errorf("failed to finalize backup: %w", err)
}
return nil
}
func (s *PostgreWalBackupService) GetRestorePlan(
@@ -299,97 +364,97 @@ func (s *PostgreWalBackupService) ReportError(
return nil
}
func (s *PostgreWalBackupService) validateWalChain(
databaseID uuid.UUID,
incomingSegment string,
walSegmentSizeBytes int64,
) (*backups_dto.UploadGapResponse, error) {
fullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(databaseID)
// IsWalChainValid checks whether the WAL chain is continuous since the last completed full backup.
func (s *PostgreWalBackupService) IsWalChainValid(
database *databases.Database,
) (*backups_dto.IsWalChainValidResponse, error) {
if err := s.validateWalBackupType(database); err != nil {
return nil, err
}
fullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(database.ID)
if err != nil {
return nil, fmt.Errorf("failed to query full backup: %w", err)
}
// No full backup exists yet: cannot accept WAL segments without a chain anchor.
if fullBackup == nil || fullBackup.PgFullBackupWalStopSegmentName == nil {
return &backups_dto.UploadGapResponse{
Error: "no_full_backup",
ExpectedSegmentName: "",
ReceivedSegmentName: incomingSegment,
return &backups_dto.IsWalChainValidResponse{
IsValid: false,
Error: "no_full_backup",
}, nil
}
stopSegment := *fullBackup.PgFullBackupWalStopSegmentName
startSegment := ""
if fullBackup.PgFullBackupWalStartSegmentName != nil {
startSegment = *fullBackup.PgFullBackupWalStartSegmentName
}
lastWal, err := s.backupRepository.FindLastWalSegmentAfter(databaseID, stopSegment)
walSegments, err := s.backupRepository.FindCompletedWalSegmentsAfter(database.ID, startSegment)
if err != nil {
return nil, fmt.Errorf("failed to query last WAL segment: %w", err)
return nil, fmt.Errorf("failed to query WAL segments: %w", err)
}
walCalculator := util_wal.NewWalCalculator(walSegmentSizeBytes)
var chainTail string
if lastWal != nil && lastWal.PgWalSegmentName != nil {
chainTail = *lastWal.PgWalSegmentName
} else {
chainTail = stopSegment
}
expectedNext, err := walCalculator.NextSegment(chainTail)
if err != nil {
return nil, fmt.Errorf("WAL arithmetic failed for %q: %w", chainTail, err)
}
if incomingSegment != expectedNext {
return &backups_dto.UploadGapResponse{
Error: "gap_detected",
ExpectedSegmentName: expectedNext,
ReceivedSegmentName: incomingSegment,
chainErr := s.validateRestoreWalChain(fullBackup, walSegments)
if chainErr != nil {
return &backups_dto.IsWalChainValidResponse{
IsValid: false,
Error: chainErr.Error,
LastContiguousSegment: chainErr.LastContiguousSegment,
}, nil
}
return nil, nil
return &backups_dto.IsWalChainValidResponse{
IsValid: true,
}, nil
}
func (s *PostgreWalBackupService) createBackupRecord(
func (s *PostgreWalBackupService) createBasebackupRecord(
databaseID uuid.UUID,
storageID uuid.UUID,
uploadType backups_core.PgWalUploadType,
dbName string,
walSegmentName string,
fullBackupWalStartSegment string,
fullBackupWalStopSegment string,
encryption backups_config.BackupEncryption,
) *backups_core.Backup {
now := time.Now().UTC()
walBackupType := backups_core.PgWalBackupTypeFullBackup
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: databaseID,
StorageID: storageID,
Status: backups_core.BackupStatusInProgress,
Encryption: encryption,
CreatedAt: now,
ID: uuid.New(),
DatabaseID: databaseID,
StorageID: storageID,
Status: backups_core.BackupStatusInProgress,
PgWalBackupType: &walBackupType,
Encryption: encryption,
CreatedAt: now,
}
backup.GenerateFilename(dbName)
if uploadType == backups_core.PgWalUploadTypeBasebackup {
walBackupType := backups_core.PgWalBackupTypeFullBackup
backup.PgWalBackupType = &walBackupType
return backup
}
if fullBackupWalStartSegment != "" {
backup.PgFullBackupWalStartSegmentName = &fullBackupWalStartSegment
}
func (s *PostgreWalBackupService) createWalSegmentRecord(
databaseID uuid.UUID,
storageID uuid.UUID,
dbName string,
walSegmentName string,
encryption backups_config.BackupEncryption,
) *backups_core.Backup {
now := time.Now().UTC()
walBackupType := backups_core.PgWalBackupTypeWalSegment
if fullBackupWalStopSegment != "" {
backup.PgFullBackupWalStopSegmentName = &fullBackupWalStopSegment
}
} else {
walBackupType := backups_core.PgWalBackupTypeWalSegment
backup.PgWalBackupType = &walBackupType
backup.PgWalSegmentName = &walSegmentName
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: databaseID,
StorageID: storageID,
Status: backups_core.BackupStatusInProgress,
PgWalBackupType: &walBackupType,
PgWalSegmentName: &walSegmentName,
Encryption: encryption,
CreatedAt: now,
}
backup.GenerateFilename(dbName)
return backup
}

View File

@@ -0,0 +1,7 @@
-- +goose Up
ALTER TABLE backups
ADD COLUMN upload_completed_at TIMESTAMPTZ;
-- +goose Down
ALTER TABLE backups
DROP COLUMN upload_completed_at;