mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f0064b4be3 | ||
|
|
94505bab3f | ||
|
|
9acf3cff09 | ||
|
|
0d7e147df6 | ||
|
|
1394b47570 | ||
|
|
a9865ae3e4 | ||
|
|
4b5478e60a | ||
|
|
6355301903 | ||
|
|
29b403a9c6 | ||
|
|
12606053f4 | ||
|
|
904b386378 | ||
|
|
1d9738b808 | ||
|
|
58b37f4c92 | ||
|
|
6c4f814c94 | ||
|
|
bcd13c27d3 | ||
|
|
120f9600bf | ||
|
|
563c7c1d64 | ||
|
|
68f15f7661 | ||
|
|
627d96a00d | ||
|
|
02b9a9ec8d | ||
|
|
415dda8752 | ||
|
|
3faf85796a | ||
|
|
edd2759f5a | ||
|
|
c283856f38 | ||
|
|
6059e1a33b | ||
|
|
2deda2e7ea | ||
|
|
acf1143752 | ||
|
|
889063a8b4 | ||
|
|
a1e20e7b10 | ||
|
|
7e76945550 | ||
|
|
d98acfc4af | ||
|
|
0ffc7c8c96 |
25
.github/workflows/ci-release.yml
vendored
25
.github/workflows/ci-release.yml
vendored
@@ -164,6 +164,27 @@ jobs:
|
||||
cd agent
|
||||
go test -count=1 -failfast ./internal/...
|
||||
|
||||
e2e-agent:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint-agent]
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Run e2e tests
|
||||
run: |
|
||||
cd agent
|
||||
make e2e
|
||||
|
||||
- name: Cleanup
|
||||
if: always()
|
||||
run: |
|
||||
cd agent/e2e
|
||||
docker compose down -v --rmi local || true
|
||||
rm -rf artifacts || true
|
||||
|
||||
# Self-hosted: performant high-frequency CPU is used to start many containers and run tests fast. Tests
|
||||
# step is the bottleneck, because we need a lot of containers and cannot parallelize tests due to shared resources
|
||||
test-backend:
|
||||
runs-on: self-hosted
|
||||
needs: [lint-backend]
|
||||
@@ -497,7 +518,7 @@ jobs:
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: node:20
|
||||
needs: [test-backend, test-frontend, test-agent]
|
||||
needs: [test-backend, test-frontend, test-agent, e2e-agent]
|
||||
if: ${{ github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, '[skip-release]') }}
|
||||
outputs:
|
||||
should_release: ${{ steps.version_bump.outputs.should_release }}
|
||||
@@ -590,7 +611,7 @@ jobs:
|
||||
|
||||
build-only:
|
||||
runs-on: self-hosted
|
||||
needs: [test-backend, test-frontend, test-agent]
|
||||
needs: [test-backend, test-frontend, test-agent, e2e-agent]
|
||||
if: ${{ github.ref == 'refs/heads/main' && contains(github.event.head_commit.message, '[skip-release]') }}
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -5,6 +5,7 @@ databasus-data/
|
||||
.env
|
||||
pgdata/
|
||||
docker-compose.yml
|
||||
!agent/e2e/docker-compose.yml
|
||||
node_modules/
|
||||
.idea
|
||||
/articles
|
||||
|
||||
34
AGENTS.md
34
AGENTS.md
@@ -73,6 +73,10 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
|
||||
- Patch the answer accordingly
|
||||
- Verify edge cases are handled
|
||||
|
||||
7. **Fix the reason, not the symptom:**
|
||||
- If you find a bug or issue, ask "Why did this happen?" and fix the root cause
|
||||
- Avoid quick fixes that don't address underlying problems
|
||||
|
||||
### Application guidelines:
|
||||
|
||||
**Scale your response to the task:**
|
||||
@@ -88,6 +92,36 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
|
||||
|
||||
## Backend guidelines
|
||||
|
||||
### Naming
|
||||
|
||||
Variable and function naming is the most important part of code readability. Always choose descriptive and meaningful names that clearly indicate the purpose and intent of the code.
|
||||
|
||||
Avoid abbreviations, unless they are widely accepted and unambiguous (e.g., `ID`, `URL`, `HTTP`). Use consistent naming conventions across the codebase.
|
||||
|
||||
Do not use one- or two-letter names. For example:
|
||||
|
||||
Bad:
|
||||
|
||||
```
|
||||
u := users.GetUser()
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
r := bufio.NewReader(pr)
|
||||
```
|
||||
|
||||
Good:
|
||||
|
||||
```
|
||||
user := users.GetUser()
|
||||
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
|
||||
bufferedReader := bufio.NewReader(pipeReader)
|
||||
```
|
||||
|
||||
Exception: widely used variable names like "db", "ctx", "req", "res", etc.
|
||||
|
||||
### Code style
|
||||
|
||||
**Always place private methods at the bottom of the file**
|
||||
|
||||
5
agent/.gitignore
vendored
5
agent/.gitignore
vendored
@@ -1,6 +1,7 @@
|
||||
main
|
||||
.env
|
||||
docker-compose.yml
|
||||
!e2e/docker-compose.yml
|
||||
pgdata
|
||||
pgdata_test/
|
||||
mysqldata/
|
||||
@@ -20,4 +21,6 @@ cmd.exe
|
||||
temp/
|
||||
valkey-data/
|
||||
victoria-logs-data/
|
||||
databasus.json
|
||||
databasus.json
|
||||
.test-tmp/
|
||||
databasus.log
|
||||
@@ -1,3 +1,5 @@
|
||||
.PHONY: run build test lint e2e e2e-clean
|
||||
|
||||
# Usage: make run ARGS="start --pg-host localhost"
|
||||
run:
|
||||
go run cmd/main.go $(ARGS)
|
||||
@@ -9,4 +11,16 @@ test:
|
||||
go test -count=1 -failfast ./internal/...
|
||||
|
||||
lint:
|
||||
golangci-lint fmt ./cmd/... ./internal/... && golangci-lint run ./cmd/... ./internal/...
|
||||
golangci-lint fmt ./cmd/... ./internal/... ./e2e/... && golangci-lint run ./cmd/... ./internal/... ./e2e/...
|
||||
|
||||
e2e:
|
||||
cd e2e && docker compose build
|
||||
cd e2e && docker compose run --rm e2e-agent-builder
|
||||
cd e2e && docker compose up -d e2e-postgres e2e-mock-server
|
||||
cd e2e && docker compose run --rm e2e-agent-runner
|
||||
cd e2e && docker compose run --rm e2e-agent-docker
|
||||
cd e2e && docker compose down -v
|
||||
|
||||
e2e-clean:
|
||||
cd e2e && docker compose down -v --rmi local
|
||||
rm -rf e2e/artifacts
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
"databasus-agent/internal/features/start"
|
||||
"databasus-agent/internal/features/upgrade"
|
||||
"databasus-agent/internal/logger"
|
||||
@@ -24,6 +28,8 @@ func main() {
|
||||
switch os.Args[1] {
|
||||
case "start":
|
||||
runStart(os.Args[2:])
|
||||
case "_run":
|
||||
runDaemon(os.Args[2:])
|
||||
case "stop":
|
||||
runStop()
|
||||
case "status":
|
||||
@@ -42,7 +48,6 @@ func main() {
|
||||
func runStart(args []string) {
|
||||
fs := flag.NewFlagSet("start", flag.ExitOnError)
|
||||
|
||||
isDebug := fs.Bool("debug", false, "Enable debug logging")
|
||||
isSkipUpdate := fs.Bool("skip-update", false, "Skip auto-update check")
|
||||
|
||||
cfg := &config.Config{}
|
||||
@@ -52,26 +57,59 @@ func runStart(args []string) {
|
||||
fmt.Fprintf(os.Stderr, "Failed to save config: %v\n", err)
|
||||
}
|
||||
|
||||
logger.Init(*isDebug)
|
||||
log := logger.GetLogger()
|
||||
|
||||
isDev := checkIsDevelopment()
|
||||
runUpdateCheck(cfg.DatabasusHost, *isSkipUpdate, isDev, log)
|
||||
|
||||
if err := start.Run(cfg, log); err != nil {
|
||||
if err := start.Start(cfg, Version, isDev, log); err != nil {
|
||||
if errors.Is(err, upgrade.ErrUpgradeRestart) {
|
||||
reexecAfterUpgrade(log)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runDaemon(args []string) {
|
||||
fs := flag.NewFlagSet("_run", flag.ExitOnError)
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
log := logger.GetLogger()
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.LoadFromJSON()
|
||||
|
||||
if err := start.RunDaemon(cfg, Version, checkIsDevelopment(), log); err != nil {
|
||||
if errors.Is(err, upgrade.ErrUpgradeRestart) {
|
||||
reexecAfterUpgrade(log)
|
||||
}
|
||||
|
||||
log.Error("Agent exited with error", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runStop() {
|
||||
logger.Init(false)
|
||||
logger.GetLogger().Info("stop: stub — not yet implemented")
|
||||
log := logger.GetLogger()
|
||||
|
||||
if err := start.Stop(log); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runStatus() {
|
||||
logger.Init(false)
|
||||
logger.GetLogger().Info("status: stub — not yet implemented")
|
||||
log := logger.GetLogger()
|
||||
|
||||
if err := start.Status(log); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runRestore(args []string) {
|
||||
@@ -81,7 +119,6 @@ func runRestore(args []string) {
|
||||
backupID := fs.String("backup-id", "", "Full backup UUID (optional)")
|
||||
targetTime := fs.String("target-time", "", "PITR target time in RFC3339 (optional)")
|
||||
isYes := fs.Bool("yes", false, "Skip confirmation prompt")
|
||||
isDebug := fs.Bool("debug", false, "Enable debug logging")
|
||||
isSkipUpdate := fs.Bool("skip-update", false, "Skip auto-update check")
|
||||
|
||||
cfg := &config.Config{}
|
||||
@@ -91,7 +128,6 @@ func runRestore(args []string) {
|
||||
fmt.Fprintf(os.Stderr, "Failed to save config: %v\n", err)
|
||||
}
|
||||
|
||||
logger.Init(*isDebug)
|
||||
log := logger.GetLogger()
|
||||
|
||||
isDev := checkIsDevelopment()
|
||||
@@ -116,12 +152,7 @@ func printUsage() {
|
||||
fmt.Fprintln(os.Stderr, " version Print agent version")
|
||||
}
|
||||
|
||||
func runUpdateCheck(host string, isSkipUpdate, isDev bool, log interface {
|
||||
Info(string, ...any)
|
||||
Warn(string, ...any)
|
||||
Error(string, ...any)
|
||||
},
|
||||
) {
|
||||
func runUpdateCheck(host string, isSkipUpdate, isDev bool, log *slog.Logger) {
|
||||
if isSkipUpdate {
|
||||
return
|
||||
}
|
||||
@@ -130,10 +161,17 @@ func runUpdateCheck(host string, isSkipUpdate, isDev bool, log interface {
|
||||
return
|
||||
}
|
||||
|
||||
if err := upgrade.CheckAndUpdate(host, Version, isDev, log); err != nil {
|
||||
apiClient := api.NewClient(host, "", log)
|
||||
|
||||
isUpgraded, err := upgrade.CheckAndUpdate(apiClient, Version, isDev, log)
|
||||
if err != nil {
|
||||
log.Error("Auto-update failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if isUpgraded {
|
||||
reexecAfterUpgrade(log)
|
||||
}
|
||||
}
|
||||
|
||||
func checkIsDevelopment() bool {
|
||||
@@ -172,3 +210,18 @@ func parseEnvMode(data []byte) bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// reexecAfterUpgrade re-runs the agent from the path reported by
// os.Executable, forwarding the current command-line arguments and
// environment unchanged. If the executable path cannot be resolved, or the
// exec call returns an error, the failure is logged and the process exits
// with status 1.
func reexecAfterUpgrade(log *slog.Logger) {
	executablePath, resolveErr := os.Executable()
	if resolveErr != nil {
		log.Error("Failed to resolve executable for re-exec", "error", resolveErr)
		os.Exit(1)
	}

	log.Info("Re-executing after upgrade...")

	// On success the exec call is expected not to return at all (the process
	// image is replaced), so reaching the lines below with a non-nil error
	// means the re-exec failed.
	execErr := syscall.Exec(executablePath, os.Args, os.Environ())
	if execErr == nil {
		return
	}

	log.Error("Failed to re-exec after upgrade", "error", execErr)
	os.Exit(1)
}
|
||||
|
||||
1
agent/e2e/.gitignore
vendored
Normal file
1
agent/e2e/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
artifacts/
|
||||
13
agent/e2e/Dockerfile.agent-builder
Normal file
13
agent/e2e/Dockerfile.agent-builder
Normal file
@@ -0,0 +1,13 @@
|
||||
# Builds agent binaries with different versions so
|
||||
# we can test upgrade behavior (v1 -> v2)
|
||||
FROM golang:1.26.1-alpine AS build
|
||||
WORKDIR /src
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
RUN CGO_ENABLED=0 go build -ldflags "-X main.Version=v1.0.0" -o /out/agent-v1 ./cmd/main.go
|
||||
RUN CGO_ENABLED=0 go build -ldflags "-X main.Version=v2.0.0" -o /out/agent-v2 ./cmd/main.go
|
||||
|
||||
FROM alpine:3.21
|
||||
COPY --from=build /out/ /out/
|
||||
CMD ["cp", "-v", "/out/agent-v1", "/out/agent-v2", "/artifacts/"]
|
||||
8
agent/e2e/Dockerfile.agent-docker
Normal file
8
agent/e2e/Dockerfile.agent-docker
Normal file
@@ -0,0 +1,8 @@
|
||||
# Runs pg_basebackup-via-docker-exec test (test 6) which tests
|
||||
# that the agent can connect to Postgres inside Docker container
|
||||
FROM docker:27-cli
|
||||
|
||||
RUN apk add --no-cache bash curl
|
||||
|
||||
WORKDIR /tmp
|
||||
ENTRYPOINT []
|
||||
14
agent/e2e/Dockerfile.agent-runner
Normal file
14
agent/e2e/Dockerfile.agent-runner
Normal file
@@ -0,0 +1,14 @@
|
||||
# Runs upgrade and host-mode pg_basebackup tests (tests 1-5). Needs
|
||||
# Postgres client tools to be installed inside the system
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ca-certificates curl gnupg2 postgresql-common && \
|
||||
/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
postgresql-client-17 && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /tmp
|
||||
ENTRYPOINT []
|
||||
10
agent/e2e/Dockerfile.mock-server
Normal file
10
agent/e2e/Dockerfile.mock-server
Normal file
@@ -0,0 +1,10 @@
|
||||
# Mock databasus API server for version checks and binary downloads. Just
|
||||
# serves static responses and files from the `artifacts` directory.
|
||||
FROM golang:1.26.1-alpine AS build
|
||||
WORKDIR /app
|
||||
COPY mock-server/main.go .
|
||||
RUN CGO_ENABLED=0 go build -o mock-server main.go
|
||||
|
||||
FROM alpine:3.21
|
||||
COPY --from=build /app/mock-server /usr/local/bin/mock-server
|
||||
ENTRYPOINT ["mock-server"]
|
||||
64
agent/e2e/docker-compose.yml
Normal file
64
agent/e2e/docker-compose.yml
Normal file
@@ -0,0 +1,64 @@
|
||||
services:
|
||||
e2e-agent-builder:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: e2e/Dockerfile.agent-builder
|
||||
volumes:
|
||||
- ./artifacts:/artifacts
|
||||
container_name: e2e-agent-builder
|
||||
|
||||
e2e-postgres:
|
||||
image: postgres:17
|
||||
environment:
|
||||
POSTGRES_DB: testdb
|
||||
POSTGRES_USER: testuser
|
||||
POSTGRES_PASSWORD: testpassword
|
||||
container_name: e2e-agent-postgres
|
||||
command: postgres -c wal_level=replica -c max_wal_senders=3
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U testuser -d testdb"]
|
||||
interval: 2s
|
||||
timeout: 5s
|
||||
retries: 30
|
||||
|
||||
e2e-mock-server:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.mock-server
|
||||
volumes:
|
||||
- ./artifacts:/artifacts:ro
|
||||
container_name: e2e-mock-server
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "-q", "--spider", "http://localhost:4050/health"]
|
||||
interval: 2s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
e2e-agent-runner:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.agent-runner
|
||||
volumes:
|
||||
- ./artifacts:/opt/agent/artifacts:ro
|
||||
- ./scripts:/opt/agent/scripts:ro
|
||||
depends_on:
|
||||
e2e-postgres:
|
||||
condition: service_healthy
|
||||
e2e-mock-server:
|
||||
condition: service_healthy
|
||||
container_name: e2e-agent-runner
|
||||
command: ["bash", "/opt/agent/scripts/run-all.sh", "host"]
|
||||
|
||||
e2e-agent-docker:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.agent-docker
|
||||
volumes:
|
||||
- ./artifacts:/opt/agent/artifacts:ro
|
||||
- ./scripts:/opt/agent/scripts:ro
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
depends_on:
|
||||
e2e-postgres:
|
||||
condition: service_healthy
|
||||
container_name: e2e-agent-docker
|
||||
command: ["bash", "/opt/agent/scripts/run-all.sh", "docker"]
|
||||
108
agent/e2e/mock-server/main.go
Normal file
108
agent/e2e/mock-server/main.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// server holds the mutable state of the mock databasus API: the version it
// reports and the filesystem path of the agent binary it serves. Both fields
// are guarded by mu because the /mock/* endpoints may change them while other
// requests are being handled.
type server struct {
	mu         sync.RWMutex
	version    string
	binaryPath string
}

// main starts the mock server on port 4050, initially reporting v2.0.0 and
// serving /artifacts/agent-v2.
func main() {
	version := "v2.0.0"
	binaryPath := "/artifacts/agent-v2"
	port := "4050"

	mockServer := &server{version: version, binaryPath: binaryPath}

	addr := ":" + port
	log.Printf("Mock server starting on %s (version=%s, binary=%s)", addr, version, binaryPath)

	if err := http.ListenAndServe(addr, newMux(mockServer)); err != nil {
		log.Fatalf("Server failed: %v", err)
	}
}

// newMux wires every mock endpoint onto a dedicated ServeMux instead of the
// process-wide default one, so the routes can be exercised in isolation.
func newMux(s *server) *http.ServeMux {
	mux := http.NewServeMux()

	mux.HandleFunc("/api/v1/system/version", s.handleVersion)
	mux.HandleFunc("/api/v1/system/agent", s.handleAgentDownload)
	mux.HandleFunc("/mock/set-version", s.handleSetVersion)
	mux.HandleFunc("/mock/set-binary-path", s.handleSetBinaryPath)
	mux.HandleFunc("/health", s.handleHealth)

	return mux
}

// handleVersion responds with the currently configured version as
// {"version": "..."}.
func (s *server) handleVersion(w http.ResponseWriter, r *http.Request) {
	s.mu.RLock()
	currentVersion := s.version
	s.mu.RUnlock()

	log.Printf("GET /api/v1/system/version -> %s", currentVersion)

	w.Header().Set("Content-Type", "application/json")

	// The status line is already committed at this point, so a failed write
	// can only be reported in the server log.
	if err := json.NewEncoder(w).Encode(map[string]string{"version": currentVersion}); err != nil {
		log.Printf("GET /api/v1/system/version: writing response failed: %v", err)
	}
}

// handleAgentDownload serves the file at the currently configured binary
// path. The "arch" query parameter is only logged; it does not influence
// which file is served.
func (s *server) handleAgentDownload(w http.ResponseWriter, r *http.Request) {
	s.mu.RLock()
	path := s.binaryPath
	s.mu.RUnlock()

	log.Printf("GET /api/v1/system/agent (arch=%s) -> serving %s", r.URL.Query().Get("arch"), path)

	http.ServeFile(w, r, path)
}

// handleSetVersion replaces the reported version. It accepts POST only, with
// a JSON body of the form {"version": "v1.2.3"}. A malformed body or an empty
// version is rejected with 400 so that a typo in a test script fails at the
// request instead of silently blanking the server state.
func (s *server) handleSetVersion(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "POST only", http.StatusMethodNotAllowed)
		return
	}

	var body struct {
		Version string `json:"version"`
	}
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	if body.Version == "" {
		http.Error(w, "version is required", http.StatusBadRequest)
		return
	}

	s.mu.Lock()
	s.version = body.Version
	s.mu.Unlock()

	log.Printf("POST /mock/set-version -> %s", body.Version)

	// Best-effort confirmation text; callers only rely on the status code.
	_, _ = fmt.Fprintf(w, "version set to %s", body.Version)
}

// handleSetBinaryPath replaces the path of the binary served by the download
// endpoint. It accepts POST only, with a JSON body of the form
// {"binaryPath": "/artifacts/agent-v2"}. A malformed body or an empty path is
// rejected with 400.
func (s *server) handleSetBinaryPath(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "POST only", http.StatusMethodNotAllowed)
		return
	}

	var body struct {
		BinaryPath string `json:"binaryPath"`
	}
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	if body.BinaryPath == "" {
		http.Error(w, "binaryPath is required", http.StatusBadRequest)
		return
	}

	s.mu.Lock()
	s.binaryPath = body.BinaryPath
	s.mu.Unlock()

	log.Printf("POST /mock/set-binary-path -> %s", body.BinaryPath)

	// Best-effort confirmation text; callers only rely on the status code.
	_, _ = fmt.Fprintf(w, "binary path set to %s", body.BinaryPath)
}

// handleHealth answers 200 "ok"; the compose healthcheck polls it.
func (s *server) handleHealth(w http.ResponseWriter, _ *http.Request) {
	w.WriteHeader(http.StatusOK)
	_, _ = w.Write([]byte("ok"))
}
|
||||
49
agent/e2e/scripts/run-all.sh
Normal file
49
agent/e2e/scripts/run-all.sh
Normal file
@@ -0,0 +1,49 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
MODE="${1:-host}"
|
||||
SCRIPT_DIR="$(dirname "$0")"
|
||||
PASSED=0
|
||||
FAILED=0
|
||||
|
||||
# run_test NAME SCRIPT
# Prints a banner for NAME, runs SCRIPT with bash, and increments the global
# PASSED or FAILED counter according to the script's exit status. Always
# returns 0 so a failing test does not abort the runner under `set -e`.
run_test() {
    local test_name="$1"
    local test_script="$2"
    local banner="========================================"

    echo ""
    echo "$banner"
    echo " $test_name"
    echo "$banner"

    if ! bash "$test_script"; then
        echo " FAILED: $test_name"
        FAILED=$((FAILED + 1))
        return 0
    fi

    echo " PASSED: $test_name"
    PASSED=$((PASSED + 1))
}
|
||||
|
||||
if [ "$MODE" = "host" ]; then
|
||||
run_test "Test 1: Upgrade success (v1 -> v2)" "$SCRIPT_DIR/test-upgrade-success.sh"
|
||||
run_test "Test 2: Upgrade skip (version matches)" "$SCRIPT_DIR/test-upgrade-skip.sh"
|
||||
run_test "Test 3: Background upgrade (v1 -> v2 while running)" "$SCRIPT_DIR/test-upgrade-background.sh"
|
||||
run_test "Test 4: pg_basebackup in PATH" "$SCRIPT_DIR/test-pg-host-path.sh"
|
||||
run_test "Test 5: pg_basebackup via bindir" "$SCRIPT_DIR/test-pg-host-bindir.sh"
|
||||
|
||||
elif [ "$MODE" = "docker" ]; then
|
||||
run_test "Test 6: pg_basebackup via docker exec" "$SCRIPT_DIR/test-pg-docker-exec.sh"
|
||||
|
||||
else
|
||||
echo "Unknown mode: $MODE (expected 'host' or 'docker')"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "========================================"
|
||||
echo " Results: $PASSED passed, $FAILED failed"
|
||||
echo "========================================"
|
||||
|
||||
if [ "$FAILED" -gt 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
61
agent/e2e/scripts/test-pg-docker-exec.sh
Normal file
61
agent/e2e/scripts/test-pg-docker-exec.sh
Normal file
@@ -0,0 +1,61 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
PG_CONTAINER="e2e-agent-postgres"
|
||||
|
||||
# Cleanup from previous runs
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
|
||||
# Copy agent binary
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Verify docker CLI works and PG container is accessible
|
||||
if ! docker exec "$PG_CONTAINER" pg_basebackup --version > /dev/null 2>&1; then
|
||||
echo "FAIL: Cannot reach pg_basebackup inside container $PG_CONTAINER (test setup issue)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run start with --skip-update and pg-type=docker
|
||||
echo "Running agent start (pg_basebackup via docker exec)..."
|
||||
# Run the agent and capture its combined stdout/stderr together with its exit
# status. The script runs under `set -e`, so a bare `OUTPUT=$(...)` would abort
# at the assignment when the agent fails, before the output is printed or the
# FAIL message is emitted. Attaching `|| EXIT_CODE=$?` keeps the script alive
# and records the real status.
EXIT_CODE=0
OUTPUT=$("$AGENT" start \
    --skip-update \
    --databasus-host http://e2e-mock-server:4050 \
    --db-id test-db-id \
    --token test-token \
    --pg-host e2e-postgres \
    --pg-port 5432 \
    --pg-user testuser \
    --pg-password testpassword \
    --pg-wal-dir /tmp/wal \
    --pg-type docker \
    --pg-docker-container-name "$PG_CONTAINER" 2>&1) || EXIT_CODE=$?

# Always show what the agent printed, including on failure.
echo "$OUTPUT"

if [ "$EXIT_CODE" -ne 0 ]; then
    echo "FAIL: Agent exited with code $EXIT_CODE"
    exit 1
fi
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "pg_basebackup verified (docker)"; then
|
||||
echo "FAIL: Expected output to contain 'pg_basebackup verified (docker)'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "PostgreSQL connection verified"; then
|
||||
echo "FAIL: Expected output to contain 'PostgreSQL connection verified'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "pg_basebackup found via docker exec and DB connection verified"
|
||||
67
agent/e2e/scripts/test-pg-host-bindir.sh
Normal file
67
agent/e2e/scripts/test-pg-host-bindir.sh
Normal file
@@ -0,0 +1,67 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
CUSTOM_BIN_DIR="/opt/pg/bin"
|
||||
|
||||
# Cleanup from previous runs
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
|
||||
# Copy agent binary
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Copy pg_basebackup into a custom directory that is not on PATH
|
||||
mkdir -p "$CUSTOM_BIN_DIR"
|
||||
cp "$(which pg_basebackup)" "$CUSTOM_BIN_DIR/pg_basebackup"
|
||||
|
||||
# Prepend an empty dir to PATH (note: this does not hide the system pg_basebackup, which stays reachable via the rest of PATH)
|
||||
export PATH="/opt/empty-path:$PATH"
|
||||
mkdir -p /opt/empty-path
|
||||
|
||||
# Verify pg_basebackup is NOT directly callable from default location
|
||||
# (we copied it, but the original is still there in debian — so we test
|
||||
# that the agent uses the custom dir, not PATH, by checking the output)
|
||||
|
||||
# Run start with --skip-update and custom bin dir
|
||||
echo "Running agent start (pg_basebackup via --pg-host-bin-dir)..."
|
||||
# Run the agent and capture its combined stdout/stderr together with its exit
# status. The script runs under `set -e`, so a bare `OUTPUT=$(...)` would abort
# at the assignment when the agent fails, before the output is printed or the
# FAIL message is emitted. Attaching `|| EXIT_CODE=$?` keeps the script alive
# and records the real status.
EXIT_CODE=0
OUTPUT=$("$AGENT" start \
    --skip-update \
    --databasus-host http://e2e-mock-server:4050 \
    --db-id test-db-id \
    --token test-token \
    --pg-host e2e-postgres \
    --pg-port 5432 \
    --pg-user testuser \
    --pg-password testpassword \
    --pg-wal-dir /tmp/wal \
    --pg-type host \
    --pg-host-bin-dir "$CUSTOM_BIN_DIR" 2>&1) || EXIT_CODE=$?

# Always show what the agent printed, including on failure.
echo "$OUTPUT"

if [ "$EXIT_CODE" -ne 0 ]; then
    echo "FAIL: Agent exited with code $EXIT_CODE"
    exit 1
fi
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "pg_basebackup verified"; then
|
||||
echo "FAIL: Expected output to contain 'pg_basebackup verified'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "PostgreSQL connection verified"; then
|
||||
echo "FAIL: Expected output to contain 'PostgreSQL connection verified'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "pg_basebackup found via custom bin dir and DB connection verified"
|
||||
59
agent/e2e/scripts/test-pg-host-path.sh
Normal file
59
agent/e2e/scripts/test-pg-host-path.sh
Normal file
@@ -0,0 +1,59 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
|
||||
# Cleanup from previous runs
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
|
||||
# Copy agent binary
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Verify pg_basebackup is in PATH
|
||||
if ! which pg_basebackup > /dev/null 2>&1; then
|
||||
echo "FAIL: pg_basebackup not found in PATH (test setup issue)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run start with --skip-update and pg-type=host
|
||||
echo "Running agent start (pg_basebackup in PATH)..."
|
||||
# Run the agent and capture its combined stdout/stderr together with its exit
# status. The script runs under `set -e`, so a bare `OUTPUT=$(...)` would abort
# at the assignment when the agent fails, before the output is printed or the
# FAIL message is emitted. Attaching `|| EXIT_CODE=$?` keeps the script alive
# and records the real status.
EXIT_CODE=0
OUTPUT=$("$AGENT" start \
    --skip-update \
    --databasus-host http://e2e-mock-server:4050 \
    --db-id test-db-id \
    --token test-token \
    --pg-host e2e-postgres \
    --pg-port 5432 \
    --pg-user testuser \
    --pg-password testpassword \
    --pg-wal-dir /tmp/wal \
    --pg-type host 2>&1) || EXIT_CODE=$?

# Always show what the agent printed, including on failure.
echo "$OUTPUT"

if [ "$EXIT_CODE" -ne 0 ]; then
    echo "FAIL: Agent exited with code $EXIT_CODE"
    exit 1
fi
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "pg_basebackup verified"; then
|
||||
echo "FAIL: Expected output to contain 'pg_basebackup verified'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "PostgreSQL connection verified"; then
|
||||
echo "FAIL: Expected output to contain 'PostgreSQL connection verified'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "pg_basebackup found in PATH and DB connection verified"
|
||||
90
agent/e2e/scripts/test-upgrade-background.sh
Normal file
90
agent/e2e/scripts/test-upgrade-background.sh
Normal file
@@ -0,0 +1,90 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
|
||||
# Cleanup from previous runs
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
|
||||
# Set mock server to v1.0.0 (same as agent — no sync upgrade on start)
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"version":"v1.0.0"}'
|
||||
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-binary-path \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"binaryPath":"/artifacts/agent-v1"}'
|
||||
|
||||
# Copy v1 binary to writable location
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Verify initial version
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v1.0.0" ]; then
|
||||
echo "FAIL: Expected initial version v1.0.0, got $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
echo "Initial version: $VERSION"
|
||||
|
||||
# Start agent as daemon (versions match → no sync upgrade)
|
||||
mkdir -p /tmp/wal
|
||||
"$AGENT" start \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--pg-wal-dir /tmp/wal \
|
||||
--pg-type host
|
||||
|
||||
echo "Agent started as daemon, waiting for stabilization..."
|
||||
sleep 2
|
||||
|
||||
# Change mock server to v2.0.0 and point to v2 binary
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"version":"v2.0.0"}'
|
||||
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-binary-path \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"binaryPath":"/artifacts/agent-v2"}'
|
||||
|
||||
echo "Mock server updated to v2.0.0, waiting for background upgrade..."
|
||||
|
||||
# Poll for upgrade (timeout 60s, poll every 3s)
|
||||
DEADLINE=$((SECONDS + 60))
|
||||
while [ $SECONDS -lt $DEADLINE ]; do
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" = "v2.0.0" ]; then
|
||||
echo "Binary upgraded to $VERSION"
|
||||
break
|
||||
fi
|
||||
sleep 3
|
||||
done
|
||||
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v2.0.0" ]; then
|
||||
echo "FAIL: Expected v2.0.0 after background upgrade, got $VERSION"
|
||||
cat databasus.log 2>/dev/null || true
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Verify agent is still running after restart
|
||||
sleep 2
|
||||
"$AGENT" status || true
|
||||
|
||||
# Cleanup
|
||||
"$AGENT" stop || true
|
||||
|
||||
echo "Background upgrade test passed"
|
||||
61
agent/e2e/scripts/test-upgrade-skip.sh
Normal file
61
agent/e2e/scripts/test-upgrade-skip.sh
Normal file
@@ -0,0 +1,61 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
|
||||
# Cleanup from previous runs
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
|
||||
# Set mock server to return v1.0.0 (same as agent)
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"version":"v1.0.0"}'
|
||||
|
||||
# Copy v1 binary to writable location
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Verify initial version
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v1.0.0" ]; then
|
||||
echo "FAIL: Expected initial version v1.0.0, got $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run start — agent should see version matches and skip upgrade
|
||||
echo "Running agent start (expecting upgrade skip)..."
|
||||
OUTPUT=$("$AGENT" start \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--pg-wal-dir /tmp/wal \
|
||||
--pg-type host 2>&1) || true
|
||||
|
||||
echo "$OUTPUT"
|
||||
|
||||
# Verify output contains "up to date"
|
||||
if ! echo "$OUTPUT" | grep -qi "up to date"; then
|
||||
echo "FAIL: Expected output to contain 'up to date'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Verify binary is still v1
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v1.0.0" ]; then
|
||||
echo "FAIL: Expected version v1.0.0 (unchanged), got $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Upgrade correctly skipped, version still $VERSION"
|
||||
66
agent/e2e/scripts/test-upgrade-success.sh
Normal file
66
agent/e2e/scripts/test-upgrade-success.sh
Normal file
@@ -0,0 +1,66 @@
|
||||
#!/bin/bash
# E2E check: `agent start` must replace a v1.0.0 binary with v2.0.0 when the
# mock server advertises v2.0.0.
set -euo pipefail

ARTIFACTS="/opt/agent/artifacts"
AGENT="/tmp/test-agent"

# Cleanup from previous runs: ask leftover agents to stop, wait up to
# 20 x 0.5s for them to exit, then force-kill whatever is still alive.
pkill -f "test-agent" 2>/dev/null || true
for i in $(seq 1 20); do
    pgrep -f "test-agent" > /dev/null 2>&1 || break
    sleep 0.5
done
pkill -9 -f "test-agent" 2>/dev/null || true
sleep 0.5
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true

# Ensure mock server returns v2.0.0 and serves v2 binary.
# -S keeps curl's error message visible; with -s alone a failed request would
# abort the script (set -e) without printing why.
curl -sSf -X POST http://e2e-mock-server:4050/mock/set-version \
    -H "Content-Type: application/json" \
    -d '{"version":"v2.0.0"}'

curl -sSf -X POST http://e2e-mock-server:4050/mock/set-binary-path \
    -H "Content-Type: application/json" \
    -d '{"binaryPath":"/artifacts/agent-v2"}'

# Copy v1 binary to writable location
cp "$ARTIFACTS/agent-v1" "$AGENT"
chmod +x "$AGENT"

# Verify initial version
VERSION=$("$AGENT" version)
if [ "$VERSION" != "v1.0.0" ]; then
    echo "FAIL: Expected initial version v1.0.0, got $VERSION"
    exit 1
fi
echo "Initial version: $VERSION"

# Run start — agent will:
# 1. Fetch version from mock (v2.0.0 != v1.0.0)
# 2. Download v2 binary from mock
# 3. Replace itself on disk
# 4. Re-exec with same args
# 5. Re-exec'd v2 fetches version (v2.0.0 == v2.0.0) → skips update
# 6. Proceeds to start → verifies pg_basebackup + DB → exits 0 (stub)
#
# The exit status of `start` is deliberately ignored (|| true); the version
# check on the on-disk binary below is the actual assertion.
echo "Running agent start (expecting upgrade v1 -> v2)..."
OUTPUT=$("$AGENT" start \
    --databasus-host http://e2e-mock-server:4050 \
    --db-id test-db-id \
    --token test-token \
    --pg-host e2e-postgres \
    --pg-port 5432 \
    --pg-user testuser \
    --pg-password testpassword \
    --pg-wal-dir /tmp/wal \
    --pg-type host 2>&1) || true

echo "$OUTPUT"

# Verify binary on disk is now v2
VERSION=$("$AGENT" version)
if [ "$VERSION" != "v2.0.0" ]; then
    echo "FAIL: Expected upgraded version v2.0.0, got $VERSION"
    exit 1
fi

echo "Binary upgraded successfully to $VERSION"
|
||||
13
agent/go.mod
13
agent/go.mod
@@ -2,10 +2,21 @@ module databasus-agent
|
||||
|
||||
go 1.26.1
|
||||
|
||||
require github.com/stretchr/testify v1.11.1
|
||||
require (
|
||||
github.com/go-resty/resty/v2 v2.17.2
|
||||
github.com/jackc/pgx/v5 v5.8.0
|
||||
github.com/klauspost/compress v1.18.4
|
||||
github.com/stretchr/testify v1.11.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
golang.org/x/net v0.43.0 // indirect
|
||||
golang.org/x/text v0.29.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
35
agent/go.sum
35
agent/go.sum
@@ -1,10 +1,43 @@
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk=
|
||||
github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
|
||||
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
|
||||
github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
|
||||
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
|
||||
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -3,6 +3,7 @@ package config
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"databasus-agent/internal/logger"
|
||||
@@ -13,9 +14,18 @@ var log = logger.GetLogger()
|
||||
const configFileName = "databasus.json"
|
||||
|
||||
type Config struct {
|
||||
DatabasusHost string `json:"databasusHost"`
|
||||
DbID string `json:"dbId"`
|
||||
Token string `json:"token"`
|
||||
DatabasusHost string `json:"databasusHost"`
|
||||
DbID string `json:"dbId"`
|
||||
Token string `json:"token"`
|
||||
PgHost string `json:"pgHost"`
|
||||
PgPort int `json:"pgPort"`
|
||||
PgUser string `json:"pgUser"`
|
||||
PgPassword string `json:"pgPassword"`
|
||||
PgType string `json:"pgType"`
|
||||
PgHostBinDir string `json:"pgHostBinDir"`
|
||||
PgDockerContainerName string `json:"pgDockerContainerName"`
|
||||
PgWalDir string `json:"pgWalDir"`
|
||||
IsDeleteWalAfterUpload *bool `json:"deleteWalAfterUpload"`
|
||||
|
||||
flags parsedFlags
|
||||
}
|
||||
@@ -24,15 +34,24 @@ type Config struct {
|
||||
// and overrides JSON values with any explicitly provided CLI flags.
|
||||
func (c *Config) LoadFromJSONAndArgs(fs *flag.FlagSet, args []string) {
|
||||
c.loadFromJSON()
|
||||
c.applyDefaults()
|
||||
c.initSources()
|
||||
|
||||
c.flags.host = fs.String(
|
||||
c.flags.databasusHost = fs.String(
|
||||
"databasus-host",
|
||||
"",
|
||||
"Databasus server URL (e.g. http://your-server:4005)",
|
||||
)
|
||||
c.flags.dbID = fs.String("db-id", "", "Database ID")
|
||||
c.flags.token = fs.String("token", "", "Agent token")
|
||||
c.flags.pgHost = fs.String("pg-host", "", "PostgreSQL host")
|
||||
c.flags.pgPort = fs.Int("pg-port", 0, "PostgreSQL port")
|
||||
c.flags.pgUser = fs.String("pg-user", "", "PostgreSQL user")
|
||||
c.flags.pgPassword = fs.String("pg-password", "", "PostgreSQL password")
|
||||
c.flags.pgType = fs.String("pg-type", "", "PostgreSQL type: host or docker")
|
||||
c.flags.pgHostBinDir = fs.String("pg-host-bin-dir", "", "Path to PG bin directory (host mode)")
|
||||
c.flags.pgDockerContainerName = fs.String("pg-docker-container-name", "", "Docker container name (docker mode)")
|
||||
c.flags.pgWalDir = fs.String("pg-wal-dir", "", "Path to WAL queue directory")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
os.Exit(1)
|
||||
@@ -54,6 +73,11 @@ func (c *Config) SaveToJSON() error {
|
||||
return os.WriteFile(configFileName, data, 0o644)
|
||||
}
|
||||
|
||||
func (c *Config) LoadFromJSON() {
|
||||
c.loadFromJSON()
|
||||
c.applyDefaults()
|
||||
}
|
||||
|
||||
func (c *Config) loadFromJSON() {
|
||||
data, err := os.ReadFile(configFileName)
|
||||
if err != nil {
|
||||
@@ -76,11 +100,35 @@ func (c *Config) loadFromJSON() {
|
||||
log.Info("Configuration loaded from " + configFileName)
|
||||
}
|
||||
|
||||
func (c *Config) applyDefaults() {
|
||||
if c.PgPort == 0 {
|
||||
c.PgPort = 5432
|
||||
}
|
||||
|
||||
if c.PgType == "" {
|
||||
c.PgType = "host"
|
||||
}
|
||||
|
||||
if c.IsDeleteWalAfterUpload == nil {
|
||||
v := true
|
||||
c.IsDeleteWalAfterUpload = &v
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) initSources() {
|
||||
c.flags.sources = map[string]string{
|
||||
"databasus-host": "not configured",
|
||||
"db-id": "not configured",
|
||||
"token": "not configured",
|
||||
"databasus-host": "not configured",
|
||||
"db-id": "not configured",
|
||||
"token": "not configured",
|
||||
"pg-host": "not configured",
|
||||
"pg-port": "not configured",
|
||||
"pg-user": "not configured",
|
||||
"pg-password": "not configured",
|
||||
"pg-type": "not configured",
|
||||
"pg-host-bin-dir": "not configured",
|
||||
"pg-docker-container-name": "not configured",
|
||||
"pg-wal-dir": "not configured",
|
||||
"delete-wal-after-upload": "not configured",
|
||||
}
|
||||
|
||||
if c.DatabasusHost != "" {
|
||||
@@ -94,11 +142,44 @@ func (c *Config) initSources() {
|
||||
if c.Token != "" {
|
||||
c.flags.sources["token"] = configFileName
|
||||
}
|
||||
|
||||
if c.PgHost != "" {
|
||||
c.flags.sources["pg-host"] = configFileName
|
||||
}
|
||||
|
||||
// PgPort always has a value after applyDefaults
|
||||
c.flags.sources["pg-port"] = configFileName
|
||||
|
||||
if c.PgUser != "" {
|
||||
c.flags.sources["pg-user"] = configFileName
|
||||
}
|
||||
|
||||
if c.PgPassword != "" {
|
||||
c.flags.sources["pg-password"] = configFileName
|
||||
}
|
||||
|
||||
// PgType always has a value after applyDefaults
|
||||
c.flags.sources["pg-type"] = configFileName
|
||||
|
||||
if c.PgHostBinDir != "" {
|
||||
c.flags.sources["pg-host-bin-dir"] = configFileName
|
||||
}
|
||||
|
||||
if c.PgDockerContainerName != "" {
|
||||
c.flags.sources["pg-docker-container-name"] = configFileName
|
||||
}
|
||||
|
||||
if c.PgWalDir != "" {
|
||||
c.flags.sources["pg-wal-dir"] = configFileName
|
||||
}
|
||||
|
||||
// IsDeleteWalAfterUpload always has a value after applyDefaults
|
||||
c.flags.sources["delete-wal-after-upload"] = configFileName
|
||||
}
|
||||
|
||||
func (c *Config) applyFlags() {
|
||||
if c.flags.host != nil && *c.flags.host != "" {
|
||||
c.DatabasusHost = *c.flags.host
|
||||
if c.flags.databasusHost != nil && *c.flags.databasusHost != "" {
|
||||
c.DatabasusHost = *c.flags.databasusHost
|
||||
c.flags.sources["databasus-host"] = "command line args"
|
||||
}
|
||||
|
||||
@@ -111,18 +192,73 @@ func (c *Config) applyFlags() {
|
||||
c.Token = *c.flags.token
|
||||
c.flags.sources["token"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgHost != nil && *c.flags.pgHost != "" {
|
||||
c.PgHost = *c.flags.pgHost
|
||||
c.flags.sources["pg-host"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgPort != nil && *c.flags.pgPort != 0 {
|
||||
c.PgPort = *c.flags.pgPort
|
||||
c.flags.sources["pg-port"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgUser != nil && *c.flags.pgUser != "" {
|
||||
c.PgUser = *c.flags.pgUser
|
||||
c.flags.sources["pg-user"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgPassword != nil && *c.flags.pgPassword != "" {
|
||||
c.PgPassword = *c.flags.pgPassword
|
||||
c.flags.sources["pg-password"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgType != nil && *c.flags.pgType != "" {
|
||||
c.PgType = *c.flags.pgType
|
||||
c.flags.sources["pg-type"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgHostBinDir != nil && *c.flags.pgHostBinDir != "" {
|
||||
c.PgHostBinDir = *c.flags.pgHostBinDir
|
||||
c.flags.sources["pg-host-bin-dir"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgDockerContainerName != nil && *c.flags.pgDockerContainerName != "" {
|
||||
c.PgDockerContainerName = *c.flags.pgDockerContainerName
|
||||
c.flags.sources["pg-docker-container-name"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgWalDir != nil && *c.flags.pgWalDir != "" {
|
||||
c.PgWalDir = *c.flags.pgWalDir
|
||||
c.flags.sources["pg-wal-dir"] = "command line args"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) logConfigSources() {
|
||||
log.Info(
|
||||
"databasus-host",
|
||||
"value",
|
||||
c.DatabasusHost,
|
||||
"source",
|
||||
c.flags.sources["databasus-host"],
|
||||
)
|
||||
log.Info("databasus-host", "value", c.DatabasusHost, "source", c.flags.sources["databasus-host"])
|
||||
log.Info("db-id", "value", c.DbID, "source", c.flags.sources["db-id"])
|
||||
log.Info("token", "value", maskSensitive(c.Token), "source", c.flags.sources["token"])
|
||||
log.Info("pg-host", "value", c.PgHost, "source", c.flags.sources["pg-host"])
|
||||
log.Info("pg-port", "value", c.PgPort, "source", c.flags.sources["pg-port"])
|
||||
log.Info("pg-user", "value", c.PgUser, "source", c.flags.sources["pg-user"])
|
||||
log.Info("pg-password", "value", maskSensitive(c.PgPassword), "source", c.flags.sources["pg-password"])
|
||||
log.Info("pg-type", "value", c.PgType, "source", c.flags.sources["pg-type"])
|
||||
log.Info("pg-host-bin-dir", "value", c.PgHostBinDir, "source", c.flags.sources["pg-host-bin-dir"])
|
||||
log.Info(
|
||||
"pg-docker-container-name",
|
||||
"value",
|
||||
c.PgDockerContainerName,
|
||||
"source",
|
||||
c.flags.sources["pg-docker-container-name"],
|
||||
)
|
||||
log.Info("pg-wal-dir", "value", c.PgWalDir, "source", c.flags.sources["pg-wal-dir"])
|
||||
log.Info(
|
||||
"delete-wal-after-upload",
|
||||
"value",
|
||||
fmt.Sprintf("%v", *c.IsDeleteWalAfterUpload),
|
||||
"source",
|
||||
c.flags.sources["delete-wal-after-upload"],
|
||||
)
|
||||
}
|
||||
|
||||
func maskSensitive(value string) string {
|
||||
|
||||
@@ -86,10 +86,12 @@ func Test_LoadFromJSONAndArgs_PartialArgsOverrideJSON(t *testing.T) {
|
||||
func Test_SaveToJSON_ConfigSavedCorrectly(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
deleteWal := true
|
||||
cfg := &Config{
|
||||
DatabasusHost: "http://save-host:4005",
|
||||
DbID: "save-db-id",
|
||||
Token: "save-token",
|
||||
DatabasusHost: "http://save-host:4005",
|
||||
DbID: "save-db-id",
|
||||
Token: "save-token",
|
||||
IsDeleteWalAfterUpload: &deleteWal,
|
||||
}
|
||||
|
||||
err := cfg.SaveToJSON()
|
||||
@@ -126,6 +128,143 @@ func Test_SaveToJSON_AfterArgsOverrideJSON_SavedFileContainsMergedValues(t *test
|
||||
assert.Equal(t, "json-token", saved.Token)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_PgFieldsLoadedFromJSON(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
deleteWal := false
|
||||
writeConfigJSON(t, dir, Config{
|
||||
DatabasusHost: "http://json-host:4005",
|
||||
DbID: "json-db-id",
|
||||
Token: "json-token",
|
||||
PgHost: "pg-json-host",
|
||||
PgPort: 5433,
|
||||
PgUser: "pg-json-user",
|
||||
PgPassword: "pg-json-pass",
|
||||
PgType: "docker",
|
||||
PgHostBinDir: "/usr/bin",
|
||||
PgDockerContainerName: "pg-container",
|
||||
PgWalDir: "/opt/wal",
|
||||
IsDeleteWalAfterUpload: &deleteWal,
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{})
|
||||
|
||||
assert.Equal(t, "pg-json-host", cfg.PgHost)
|
||||
assert.Equal(t, 5433, cfg.PgPort)
|
||||
assert.Equal(t, "pg-json-user", cfg.PgUser)
|
||||
assert.Equal(t, "pg-json-pass", cfg.PgPassword)
|
||||
assert.Equal(t, "docker", cfg.PgType)
|
||||
assert.Equal(t, "/usr/bin", cfg.PgHostBinDir)
|
||||
assert.Equal(t, "pg-container", cfg.PgDockerContainerName)
|
||||
assert.Equal(t, "/opt/wal", cfg.PgWalDir)
|
||||
assert.Equal(t, false, *cfg.IsDeleteWalAfterUpload)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_PgFieldsLoadedFromArgs(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--pg-host", "arg-pg-host",
|
||||
"--pg-port", "5433",
|
||||
"--pg-user", "arg-pg-user",
|
||||
"--pg-password", "arg-pg-pass",
|
||||
"--pg-type", "docker",
|
||||
"--pg-host-bin-dir", "/custom/bin",
|
||||
"--pg-docker-container-name", "my-pg",
|
||||
"--pg-wal-dir", "/var/wal",
|
||||
})
|
||||
|
||||
assert.Equal(t, "arg-pg-host", cfg.PgHost)
|
||||
assert.Equal(t, 5433, cfg.PgPort)
|
||||
assert.Equal(t, "arg-pg-user", cfg.PgUser)
|
||||
assert.Equal(t, "arg-pg-pass", cfg.PgPassword)
|
||||
assert.Equal(t, "docker", cfg.PgType)
|
||||
assert.Equal(t, "/custom/bin", cfg.PgHostBinDir)
|
||||
assert.Equal(t, "my-pg", cfg.PgDockerContainerName)
|
||||
assert.Equal(t, "/var/wal", cfg.PgWalDir)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_PgArgsOverrideJSON(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
writeConfigJSON(t, dir, Config{
|
||||
PgHost: "json-host",
|
||||
PgPort: 5432,
|
||||
PgUser: "json-user",
|
||||
PgType: "host",
|
||||
PgWalDir: "/json/wal",
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--pg-host", "arg-host",
|
||||
"--pg-port", "5433",
|
||||
"--pg-user", "arg-user",
|
||||
"--pg-type", "docker",
|
||||
"--pg-docker-container-name", "my-container",
|
||||
"--pg-wal-dir", "/arg/wal",
|
||||
})
|
||||
|
||||
assert.Equal(t, "arg-host", cfg.PgHost)
|
||||
assert.Equal(t, 5433, cfg.PgPort)
|
||||
assert.Equal(t, "arg-user", cfg.PgUser)
|
||||
assert.Equal(t, "docker", cfg.PgType)
|
||||
assert.Equal(t, "my-container", cfg.PgDockerContainerName)
|
||||
assert.Equal(t, "/arg/wal", cfg.PgWalDir)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_DefaultsApplied_WhenNoJSONAndNoArgs(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{})
|
||||
|
||||
assert.Equal(t, 5432, cfg.PgPort)
|
||||
assert.Equal(t, "host", cfg.PgType)
|
||||
require.NotNil(t, cfg.IsDeleteWalAfterUpload)
|
||||
assert.Equal(t, true, *cfg.IsDeleteWalAfterUpload)
|
||||
}
|
||||
|
||||
func Test_SaveToJSON_PgFieldsSavedCorrectly(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
deleteWal := false
|
||||
cfg := &Config{
|
||||
DatabasusHost: "http://host:4005",
|
||||
DbID: "db-id",
|
||||
Token: "token",
|
||||
PgHost: "pg-host",
|
||||
PgPort: 5433,
|
||||
PgUser: "pg-user",
|
||||
PgPassword: "pg-pass",
|
||||
PgType: "docker",
|
||||
PgHostBinDir: "/usr/bin",
|
||||
PgDockerContainerName: "pg-container",
|
||||
PgWalDir: "/opt/wal",
|
||||
IsDeleteWalAfterUpload: &deleteWal,
|
||||
}
|
||||
|
||||
err := cfg.SaveToJSON()
|
||||
require.NoError(t, err)
|
||||
|
||||
saved := readConfigJSON(t)
|
||||
|
||||
assert.Equal(t, "pg-host", saved.PgHost)
|
||||
assert.Equal(t, 5433, saved.PgPort)
|
||||
assert.Equal(t, "pg-user", saved.PgUser)
|
||||
assert.Equal(t, "pg-pass", saved.PgPassword)
|
||||
assert.Equal(t, "docker", saved.PgType)
|
||||
assert.Equal(t, "/usr/bin", saved.PgHostBinDir)
|
||||
assert.Equal(t, "pg-container", saved.PgDockerContainerName)
|
||||
assert.Equal(t, "/opt/wal", saved.PgWalDir)
|
||||
require.NotNil(t, saved.IsDeleteWalAfterUpload)
|
||||
assert.Equal(t, false, *saved.IsDeleteWalAfterUpload)
|
||||
}
|
||||
|
||||
func setupTempDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
package config
|
||||
|
||||
type parsedFlags struct {
|
||||
host *string
|
||||
dbID *string
|
||||
token *string
|
||||
databasusHost *string
|
||||
dbID *string
|
||||
token *string
|
||||
pgHost *string
|
||||
pgPort *int
|
||||
pgUser *string
|
||||
pgPassword *string
|
||||
pgType *string
|
||||
pgHostBinDir *string
|
||||
pgDockerContainerName *string
|
||||
pgWalDir *string
|
||||
|
||||
sources map[string]string
|
||||
}
|
||||
|
||||
288
agent/internal/features/api/api.go
Normal file
288
agent/internal/features/api/api.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
chainValidPath = "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup"
|
||||
nextBackupTimePath = "/api/v1/backups/postgres/wal/next-full-backup-time"
|
||||
walUploadPath = "/api/v1/backups/postgres/wal/upload/wal"
|
||||
fullStartPath = "/api/v1/backups/postgres/wal/upload/full-start"
|
||||
fullCompletePath = "/api/v1/backups/postgres/wal/upload/full-complete"
|
||||
reportErrorPath = "/api/v1/backups/postgres/wal/error"
|
||||
versionPath = "/api/v1/system/version"
|
||||
agentBinaryPath = "/api/v1/system/agent"
|
||||
|
||||
apiCallTimeout = 30 * time.Second
|
||||
maxRetryAttempts = 3
|
||||
retryBaseDelay = 1 * time.Second
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
json *resty.Client
|
||||
stream *resty.Client
|
||||
host string
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
func NewClient(host, token string, log *slog.Logger) *Client {
|
||||
setAuth := func(_ *resty.Client, req *resty.Request) error {
|
||||
if token != "" {
|
||||
req.SetHeader("Authorization", token)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
jsonClient := resty.New().
|
||||
SetTimeout(apiCallTimeout).
|
||||
SetRetryCount(maxRetryAttempts - 1).
|
||||
SetRetryWaitTime(retryBaseDelay).
|
||||
SetRetryMaxWaitTime(4 * retryBaseDelay).
|
||||
AddRetryCondition(func(resp *resty.Response, err error) bool {
|
||||
return err != nil || resp.StatusCode() >= 500
|
||||
}).
|
||||
OnBeforeRequest(setAuth)
|
||||
|
||||
streamClient := resty.New().
|
||||
OnBeforeRequest(setAuth)
|
||||
|
||||
return &Client{
|
||||
json: jsonClient,
|
||||
stream: streamClient,
|
||||
host: host,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) CheckWalChainValidity(ctx context.Context) (*WalChainValidityResponse, error) {
|
||||
var resp WalChainValidityResponse
|
||||
|
||||
httpResp, err := c.json.R().
|
||||
SetContext(ctx).
|
||||
SetResult(&resp).
|
||||
Get(c.buildURL(chainValidPath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.checkResponse(httpResp, "check WAL chain validity"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetNextFullBackupTime(ctx context.Context) (*NextFullBackupTimeResponse, error) {
|
||||
var resp NextFullBackupTimeResponse
|
||||
|
||||
httpResp, err := c.json.R().
|
||||
SetContext(ctx).
|
||||
SetResult(&resp).
|
||||
Get(c.buildURL(nextBackupTimePath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.checkResponse(httpResp, "get next full backup time"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (c *Client) ReportBackupError(ctx context.Context, errMsg string) error {
|
||||
httpResp, err := c.json.R().
|
||||
SetContext(ctx).
|
||||
SetBody(reportErrorRequest{Error: errMsg}).
|
||||
Post(c.buildURL(reportErrorPath))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.checkResponse(httpResp, "report backup error")
|
||||
}
|
||||
|
||||
func (c *Client) UploadBasebackup(
|
||||
ctx context.Context,
|
||||
body io.Reader,
|
||||
) (*UploadBasebackupResponse, error) {
|
||||
resp, err := c.stream.R().
|
||||
SetContext(ctx).
|
||||
SetBody(body).
|
||||
SetHeader("Content-Type", "application/octet-stream").
|
||||
SetDoNotParseResponse(true).
|
||||
Post(c.buildURL(fullStartPath))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upload request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.RawBody().Close() }()
|
||||
|
||||
if resp.StatusCode() != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.RawBody())
|
||||
|
||||
return nil, fmt.Errorf("upload failed with status %d: %s", resp.StatusCode(), string(respBody))
|
||||
}
|
||||
|
||||
var result UploadBasebackupResponse
|
||||
if err := json.NewDecoder(resp.RawBody()).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("decode upload response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *Client) FinalizeBasebackup(
|
||||
ctx context.Context,
|
||||
backupID string,
|
||||
startSegment string,
|
||||
stopSegment string,
|
||||
) error {
|
||||
resp, err := c.json.R().
|
||||
SetContext(ctx).
|
||||
SetBody(finalizeBasebackupRequest{
|
||||
BackupID: backupID,
|
||||
StartSegment: startSegment,
|
||||
StopSegment: stopSegment,
|
||||
}).
|
||||
Post(c.buildURL(fullCompletePath))
|
||||
if err != nil {
|
||||
return fmt.Errorf("finalize request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode() != http.StatusOK {
|
||||
return fmt.Errorf("finalize failed with status %d: %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) FinalizeBasebackupWithError(
|
||||
ctx context.Context,
|
||||
backupID string,
|
||||
errMsg string,
|
||||
) error {
|
||||
resp, err := c.json.R().
|
||||
SetContext(ctx).
|
||||
SetBody(finalizeBasebackupRequest{
|
||||
BackupID: backupID,
|
||||
Error: &errMsg,
|
||||
}).
|
||||
Post(c.buildURL(fullCompletePath))
|
||||
if err != nil {
|
||||
return fmt.Errorf("finalize-with-error request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode() != http.StatusOK {
|
||||
return fmt.Errorf("finalize-with-error failed with status %d: %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) UploadWalSegment(
|
||||
ctx context.Context,
|
||||
segmentName string,
|
||||
body io.Reader,
|
||||
) (*UploadWalSegmentResult, error) {
|
||||
resp, err := c.stream.R().
|
||||
SetContext(ctx).
|
||||
SetBody(body).
|
||||
SetHeader("Content-Type", "application/octet-stream").
|
||||
SetHeader("X-Wal-Segment-Name", segmentName).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(c.buildURL(walUploadPath))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upload request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.RawBody().Close() }()
|
||||
|
||||
switch resp.StatusCode() {
|
||||
case http.StatusNoContent:
|
||||
return &UploadWalSegmentResult{IsGapDetected: false}, nil
|
||||
|
||||
case http.StatusConflict:
|
||||
var errResp uploadErrorResponse
|
||||
|
||||
if err := json.NewDecoder(resp.RawBody()).Decode(&errResp); err != nil {
|
||||
return &UploadWalSegmentResult{IsGapDetected: true}, nil
|
||||
}
|
||||
|
||||
return &UploadWalSegmentResult{
|
||||
IsGapDetected: true,
|
||||
ExpectedSegmentName: errResp.ExpectedSegmentName,
|
||||
ReceivedSegmentName: errResp.ReceivedSegmentName,
|
||||
}, nil
|
||||
|
||||
default:
|
||||
respBody, _ := io.ReadAll(resp.RawBody())
|
||||
|
||||
return nil, fmt.Errorf("upload failed with status %d: %s", resp.StatusCode(), string(respBody))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) FetchServerVersion(ctx context.Context) (string, error) {
|
||||
var ver versionResponse
|
||||
|
||||
httpResp, err := c.json.R().
|
||||
SetContext(ctx).
|
||||
SetResult(&ver).
|
||||
Get(c.buildURL(versionPath))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := c.checkResponse(httpResp, "fetch server version"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return ver.Version, nil
|
||||
}
|
||||
|
||||
func (c *Client) DownloadAgentBinary(ctx context.Context, arch, destPath string) error {
|
||||
resp, err := c.stream.R().
|
||||
SetContext(ctx).
|
||||
SetQueryParam("arch", arch).
|
||||
SetDoNotParseResponse(true).
|
||||
Get(c.buildURL(agentBinaryPath))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.RawBody().Close() }()
|
||||
|
||||
if resp.StatusCode() != http.StatusOK {
|
||||
return fmt.Errorf("server returned %d for agent download", resp.StatusCode())
|
||||
}
|
||||
|
||||
f, err := os.Create(destPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
_, err = io.Copy(f, resp.RawBody())
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) buildURL(path string) string {
|
||||
return c.host + path
|
||||
}
|
||||
|
||||
func (c *Client) checkResponse(resp *resty.Response, method string) error {
|
||||
if resp.StatusCode() >= 400 {
|
||||
return fmt.Errorf("%s: server returned status %d: %s", method, resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
44
agent/internal/features/api/dto.go
Normal file
44
agent/internal/features/api/dto.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package api
|
||||
|
||||
import "time"
|
||||
|
||||
type WalChainValidityResponse struct {
|
||||
IsValid bool `json:"isValid"`
|
||||
Error string `json:"error,omitempty"`
|
||||
LastContiguousSegment string `json:"lastContiguousSegment,omitempty"`
|
||||
}
|
||||
|
||||
type NextFullBackupTimeResponse struct {
|
||||
NextFullBackupTime *time.Time `json:"nextFullBackupTime"`
|
||||
}
|
||||
|
||||
type UploadWalSegmentResult struct {
|
||||
IsGapDetected bool
|
||||
ExpectedSegmentName string
|
||||
ReceivedSegmentName string
|
||||
}
|
||||
|
||||
type reportErrorRequest struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type versionResponse struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
type UploadBasebackupResponse struct {
|
||||
BackupID string `json:"backupId"`
|
||||
}
|
||||
|
||||
type finalizeBasebackupRequest struct {
|
||||
BackupID string `json:"backupId"`
|
||||
StartSegment string `json:"startSegment"`
|
||||
StopSegment string `json:"stopSegment"`
|
||||
Error *string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// uploadErrorResponse is the JSON error body of an upload response.
type uploadErrorResponse struct {
	// Error is the error text from the "error" key.
	Error string `json:"error"`
	// ExpectedSegmentName is read from the "expectedSegmentName" key.
	ExpectedSegmentName string `json:"expectedSegmentName"`
	// ReceivedSegmentName is read from the "receivedSegmentName" key.
	ReceivedSegmentName string `json:"receivedSegmentName"`
}
|
||||
292
agent/internal/features/full_backup/backuper.go
Normal file
292
agent/internal/features/full_backup/backuper.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package full_backup
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
)
|
||||
|
||||
const (
	// checkInterval is the period of the ticker that re-runs the basebackup check.
	checkInterval = 30 * time.Second
	// retryDelay is the default pause between failed basebackup attempts.
	retryDelay = 1 * time.Minute
	// uploadTimeout bounds a single basebackup upload request.
	uploadTimeout = 30 * time.Minute
)

// retryDelayOverride, when non-nil, is used instead of retryDelay.
var retryDelayOverride *time.Duration

// CmdBuilder builds the command that produces the basebackup stream for a given context.
type CmdBuilder func(ctx context.Context) *exec.Cmd
|
||||
|
||||
// FullBackuper runs pg_basebackup when the WAL chain is broken or a scheduled backup is due.
|
||||
//
|
||||
// Every 30 seconds it checks two conditions via the Databasus API:
|
||||
// 1. WAL chain validity — if broken or no full backup exists, triggers an immediate basebackup.
|
||||
// 2. Scheduled backup time — if the next full backup time has passed, triggers a basebackup.
|
||||
//
|
||||
// Only one basebackup runs at a time (guarded by atomic bool).
|
||||
// On failure the error is reported to the server and the backup retries after 1 minute, indefinitely.
|
||||
// WAL segment uploads (handled by wal.Streamer) continue independently and are not paused.
|
||||
//
|
||||
// pg_basebackup runs as "pg_basebackup -Ft -D - -X none --verbose". Stdout (tar) is zstd-compressed
|
||||
// and uploaded to the server. Stderr is parsed for WAL start/stop segment names (LSN → segment arithmetic).
|
||||
type FullBackuper struct {
|
||||
cfg *config.Config
|
||||
apiClient *api.Client
|
||||
log *slog.Logger
|
||||
isRunning atomic.Bool
|
||||
cmdBuilder CmdBuilder
|
||||
}
|
||||
|
||||
func NewFullBackuper(cfg *config.Config, apiClient *api.Client, log *slog.Logger) *FullBackuper {
|
||||
backuper := &FullBackuper{
|
||||
cfg: cfg,
|
||||
apiClient: apiClient,
|
||||
log: log,
|
||||
}
|
||||
|
||||
backuper.cmdBuilder = backuper.defaultCmdBuilder
|
||||
|
||||
return backuper
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) Run(ctx context.Context) {
|
||||
backuper.log.Info("Full backuper started")
|
||||
|
||||
backuper.checkAndRunIfNeeded(ctx)
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
backuper.log.Info("Full backuper stopping")
|
||||
return
|
||||
case <-ticker.C:
|
||||
backuper.checkAndRunIfNeeded(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) checkAndRunIfNeeded(ctx context.Context) {
|
||||
if backuper.isRunning.Load() {
|
||||
backuper.log.Debug("Skipping check: basebackup already in progress")
|
||||
return
|
||||
}
|
||||
|
||||
chainResp, err := backuper.apiClient.CheckWalChainValidity(ctx)
|
||||
if err != nil {
|
||||
backuper.log.Error("Failed to check WAL chain validity", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !chainResp.IsValid {
|
||||
backuper.log.Info("WAL chain is invalid, triggering basebackup",
|
||||
"error", chainResp.Error,
|
||||
"lastContiguousSegment", chainResp.LastContiguousSegment,
|
||||
)
|
||||
|
||||
backuper.runBasebackupWithRetry(ctx)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
nextTimeResp, err := backuper.apiClient.GetNextFullBackupTime(ctx)
|
||||
if err != nil {
|
||||
backuper.log.Error("Failed to check next full backup time", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if nextTimeResp.NextFullBackupTime == nil || !nextTimeResp.NextFullBackupTime.After(time.Now().UTC()) {
|
||||
backuper.log.Info("Scheduled full backup is due, triggering basebackup")
|
||||
backuper.runBasebackupWithRetry(ctx)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backuper.log.Debug("No basebackup needed",
|
||||
"nextFullBackupTime", nextTimeResp.NextFullBackupTime,
|
||||
)
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) runBasebackupWithRetry(ctx context.Context) {
|
||||
if !backuper.isRunning.CompareAndSwap(false, true) {
|
||||
backuper.log.Debug("Skipping basebackup: already running")
|
||||
return
|
||||
}
|
||||
defer backuper.isRunning.Store(false)
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
backuper.log.Info("Starting pg_basebackup")
|
||||
|
||||
err := backuper.executeAndUploadBasebackup(ctx)
|
||||
if err == nil {
|
||||
backuper.log.Info("Basebackup completed successfully")
|
||||
return
|
||||
}
|
||||
|
||||
backuper.log.Error("Basebackup failed", "error", err)
|
||||
backuper.reportError(ctx, err.Error())
|
||||
|
||||
delay := retryDelay
|
||||
if retryDelayOverride != nil {
|
||||
delay = *retryDelayOverride
|
||||
}
|
||||
|
||||
backuper.log.Info("Retrying basebackup after delay", "delay", delay)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(delay):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) executeAndUploadBasebackup(ctx context.Context) error {
|
||||
cmd := backuper.cmdBuilder(ctx)
|
||||
|
||||
var stderrBuf bytes.Buffer
|
||||
cmd.Stderr = &stderrBuf
|
||||
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("start pg_basebackup: %w", err)
|
||||
}
|
||||
|
||||
// Phase 1: Stream compressed data via io.Pipe directly to the API.
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
go backuper.compressAndStream(pipeWriter, stdoutPipe)
|
||||
|
||||
uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout)
|
||||
defer cancel()
|
||||
|
||||
uploadResp, uploadErr := backuper.apiClient.UploadBasebackup(uploadCtx, pipeReader)
|
||||
|
||||
cmdErr := cmd.Wait()
|
||||
|
||||
if uploadErr != nil {
|
||||
return fmt.Errorf("upload basebackup: %w", uploadErr)
|
||||
}
|
||||
|
||||
if cmdErr != nil {
|
||||
errMsg := fmt.Sprintf("pg_basebackup exited with error: %v (stderr: %s)", cmdErr, stderrBuf.String())
|
||||
_ = backuper.apiClient.FinalizeBasebackupWithError(ctx, uploadResp.BackupID, errMsg)
|
||||
|
||||
return fmt.Errorf("pg_basebackup: %w", cmdErr)
|
||||
}
|
||||
|
||||
// Phase 2: Parse stderr for WAL segments and finalize the backup.
|
||||
stderrStr := stderrBuf.String()
|
||||
backuper.log.Debug("pg_basebackup stderr", "stderr", stderrStr)
|
||||
|
||||
startSegment, stopSegment, err := ParseBasebackupStderr(stderrStr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("parse pg_basebackup stderr: %v", err)
|
||||
_ = backuper.apiClient.FinalizeBasebackupWithError(ctx, uploadResp.BackupID, errMsg)
|
||||
|
||||
return fmt.Errorf("parse pg_basebackup stderr: %w", err)
|
||||
}
|
||||
|
||||
backuper.log.Info("Basebackup WAL segments parsed",
|
||||
"startSegment", startSegment,
|
||||
"stopSegment", stopSegment,
|
||||
"backupId", uploadResp.BackupID,
|
||||
)
|
||||
|
||||
if err := backuper.apiClient.FinalizeBasebackup(ctx, uploadResp.BackupID, startSegment, stopSegment); err != nil {
|
||||
return fmt.Errorf("finalize basebackup: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) compressAndStream(pipeWriter *io.PipeWriter, reader io.Reader) {
|
||||
encoder, err := zstd.NewWriter(pipeWriter,
|
||||
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(5)),
|
||||
zstd.WithEncoderCRC(true),
|
||||
)
|
||||
if err != nil {
|
||||
_ = pipeWriter.CloseWithError(fmt.Errorf("create zstd encoder: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := io.Copy(encoder, reader); err != nil {
|
||||
_ = encoder.Close()
|
||||
_ = pipeWriter.CloseWithError(fmt.Errorf("compress: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := encoder.Close(); err != nil {
|
||||
_ = pipeWriter.CloseWithError(fmt.Errorf("close encoder: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
_ = pipeWriter.Close()
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) reportError(ctx context.Context, errMsg string) {
|
||||
if err := backuper.apiClient.ReportBackupError(ctx, errMsg); err != nil {
|
||||
backuper.log.Error("Failed to report error to server", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) defaultCmdBuilder(ctx context.Context) *exec.Cmd {
|
||||
switch backuper.cfg.PgType {
|
||||
case "docker":
|
||||
return backuper.buildDockerCmd(ctx)
|
||||
default:
|
||||
return backuper.buildHostCmd(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) buildHostCmd(ctx context.Context) *exec.Cmd {
|
||||
binary := "pg_basebackup"
|
||||
if backuper.cfg.PgHostBinDir != "" {
|
||||
binary = filepath.Join(backuper.cfg.PgHostBinDir, "pg_basebackup")
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, binary,
|
||||
"-Ft", "-D", "-", "-X", "none", "--verbose",
|
||||
"-h", backuper.cfg.PgHost,
|
||||
"-p", fmt.Sprintf("%d", backuper.cfg.PgPort),
|
||||
"-U", backuper.cfg.PgUser,
|
||||
)
|
||||
|
||||
cmd.Env = append(os.Environ(), "PGPASSWORD="+backuper.cfg.PgPassword)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) buildDockerCmd(ctx context.Context) *exec.Cmd {
|
||||
cmd := exec.CommandContext(ctx, "docker", "exec",
|
||||
"-e", "PGPASSWORD="+backuper.cfg.PgPassword,
|
||||
"-i", backuper.cfg.PgDockerContainerName,
|
||||
"pg_basebackup",
|
||||
"-Ft", "-D", "-", "-X", "none", "--verbose",
|
||||
"-h", backuper.cfg.PgHost,
|
||||
"-p", fmt.Sprintf("%d", backuper.cfg.PgPort),
|
||||
"-U", backuper.cfg.PgUser,
|
||||
)
|
||||
|
||||
return cmd
|
||||
}
|
||||
671
agent/internal/features/full_backup/backuper_test.go
Normal file
671
agent/internal/features/full_backup/backuper_test.go
Normal file
@@ -0,0 +1,671 @@
|
||||
package full_backup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
// URL paths the test HTTP server dispatches on, and the backup ID it hands out.
const (
	testChainValidPath     = "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup"
	testNextBackupTimePath = "/api/v1/backups/postgres/wal/next-full-backup-time"
	testFullStartPath      = "/api/v1/backups/postgres/wal/upload/full-start"
	testFullCompletePath   = "/api/v1/backups/postgres/wal/upload/full-complete"
	testReportErrorPath    = "/api/v1/backups/postgres/wal/error"

	testBackupID = "test-backup-id-1234"
)
|
||||
|
||||
func Test_RunFullBackup_WhenChainBroken_BasebackupTriggered(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var uploadReceived bool
|
||||
var uploadHeaders http.Header
|
||||
var finalizeReceived bool
|
||||
var finalizeBody map[string]any
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "wal_chain_broken",
|
||||
LastContiguousSegment: "000000010000000100000011",
|
||||
})
|
||||
case testFullStartPath:
|
||||
mu.Lock()
|
||||
uploadReceived = true
|
||||
uploadHeaders = r.Header.Clone()
|
||||
mu.Unlock()
|
||||
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
mu.Lock()
|
||||
finalizeReceived = true
|
||||
_ = json.NewDecoder(r.Body).Decode(&finalizeBody)
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "test-backup-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return finalizeReceived
|
||||
}, 5*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.True(t, uploadReceived)
|
||||
assert.Equal(t, "application/octet-stream", uploadHeaders.Get("Content-Type"))
|
||||
assert.Equal(t, "test-token", uploadHeaders.Get("Authorization"))
|
||||
|
||||
assert.True(t, finalizeReceived)
|
||||
assert.Equal(t, testBackupID, finalizeBody["backupId"])
|
||||
assert.Equal(t, "000000010000000000000002", finalizeBody["startSegment"])
|
||||
assert.Equal(t, "000000010000000000000002", finalizeBody["stopSegment"])
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenScheduledBackupDue_BasebackupTriggered(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var finalizeReceived bool
|
||||
|
||||
pastTime := time.Now().UTC().Add(-1 * time.Hour)
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{IsValid: true})
|
||||
case testNextBackupTimePath:
|
||||
writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: &pastTime})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
mu.Lock()
|
||||
finalizeReceived = true
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "scheduled-backup-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return finalizeReceived
|
||||
}, 5*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.True(t, finalizeReceived)
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenNoFullBackupExists_ImmediateBasebackupTriggered(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var finalizeReceived bool
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
mu.Lock()
|
||||
finalizeReceived = true
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "first-backup-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return finalizeReceived
|
||||
}, 5*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.True(t, finalizeReceived)
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenUploadFails_RetriesAfterDelay(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var uploadAttempts int
|
||||
var errorReported bool
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
|
||||
mu.Lock()
|
||||
uploadAttempts++
|
||||
attempt := uploadAttempts
|
||||
mu.Unlock()
|
||||
|
||||
if attempt == 1 {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(`{"error":"storage unavailable"}`))
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
case testReportErrorPath:
|
||||
mu.Lock()
|
||||
errorReported = true
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "retry-backup-data", validStderr())
|
||||
|
||||
origRetryDelay := retryDelay
|
||||
setRetryDelay(100 * time.Millisecond)
|
||||
defer setRetryDelay(origRetryDelay)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return uploadAttempts >= 2
|
||||
}, 10*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.GreaterOrEqual(t, uploadAttempts, 2)
|
||||
assert.True(t, errorReported)
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenAlreadyRunning_SkipsExecution(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var uploadCount int
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
|
||||
mu.Lock()
|
||||
uploadCount++
|
||||
mu.Unlock()
|
||||
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
|
||||
|
||||
fb.isRunning.Store(true)
|
||||
|
||||
fb.checkAndRunIfNeeded(context.Background())
|
||||
|
||||
mu.Lock()
|
||||
count := uploadCount
|
||||
mu.Unlock()
|
||||
|
||||
assert.Equal(t, 0, count, "should not trigger backup when already running")
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenContextCancelled_StopsCleanly(t *testing.T) {
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
case testFullCompletePath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
case testReportErrorPath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
|
||||
|
||||
origRetryDelay := retryDelay
|
||||
setRetryDelay(5 * time.Second)
|
||||
defer setRetryDelay(origRetryDelay)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
fb.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Run should have stopped after context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenChainValidAndNotScheduled_NoBasebackupTriggered(t *testing.T) {
|
||||
var uploadReceived atomic.Bool
|
||||
|
||||
futureTime := time.Now().UTC().Add(24 * time.Hour)
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{IsValid: true})
|
||||
case testNextBackupTimePath:
|
||||
writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: &futureTime})
|
||||
case testFullStartPath:
|
||||
uploadReceived.Store(true)
|
||||
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
assert.False(t, uploadReceived.Load(), "should not trigger backup when chain valid and not scheduled")
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenStderrParsingFails_FinalizesWithErrorAndRetries(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var errorReported bool
|
||||
var finalizeWithErrorReceived bool
|
||||
var finalizeBody map[string]any
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
mu.Lock()
|
||||
finalizeWithErrorReceived = true
|
||||
_ = json.NewDecoder(r.Body).Decode(&finalizeBody)
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
case testReportErrorPath:
|
||||
mu.Lock()
|
||||
errorReported = true
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", "pg_basebackup: unexpected output with no LSN info")
|
||||
|
||||
origRetryDelay := retryDelay
|
||||
setRetryDelay(100 * time.Millisecond)
|
||||
defer setRetryDelay(origRetryDelay)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return errorReported
|
||||
}, 2*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.True(t, errorReported)
|
||||
assert.True(t, finalizeWithErrorReceived, "should finalize with error when stderr parsing fails")
|
||||
assert.Equal(t, testBackupID, finalizeBody["backupId"])
|
||||
assert.NotNil(t, finalizeBody["error"], "finalize should include error message")
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenNextBackupTimeNull_BasebackupTriggered(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var finalizeReceived bool
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{IsValid: true})
|
||||
case testNextBackupTimePath:
|
||||
writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: nil})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
mu.Lock()
|
||||
finalizeReceived = true
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "first-run-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return finalizeReceived
|
||||
}, 5*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.True(t, finalizeReceived)
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenChainValidityReturns401_NoBasebackupTriggered(t *testing.T) {
|
||||
var uploadReceived atomic.Bool
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"invalid token"}`))
|
||||
case testFullStartPath:
|
||||
uploadReceived.Store(true)
|
||||
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
assert.False(t, uploadReceived.Load(), "should not trigger backup when API returns 401")
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenUploadSucceeds_BodyIsZstdCompressed(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var receivedBody []byte
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
})
|
||||
case testFullStartPath:
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
|
||||
mu.Lock()
|
||||
receivedBody = body
|
||||
mu.Unlock()
|
||||
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
originalContent := "test-backup-content-for-compression-check"
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, originalContent, validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return len(receivedBody) > 0
|
||||
}, 5*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
body := receivedBody
|
||||
mu.Unlock()
|
||||
|
||||
decoder, err := zstd.NewReader(nil)
|
||||
require.NoError(t, err)
|
||||
defer decoder.Close()
|
||||
|
||||
decompressed, err := decoder.DecodeAll(body, nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalContent, string(decompressed))
|
||||
}
|
||||
|
||||
// newTestServer starts an httptest server around handler and registers its
// shutdown with t.Cleanup.
func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
	t.Helper()

	srv := httptest.NewServer(handler)
	t.Cleanup(func() { srv.Close() })

	return srv
}
|
||||
|
||||
func newTestFullBackuper(serverURL string) *FullBackuper {
|
||||
cfg := &config.Config{
|
||||
DatabasusHost: serverURL,
|
||||
DbID: "test-db-id",
|
||||
Token: "test-token",
|
||||
PgHost: "localhost",
|
||||
PgPort: 5432,
|
||||
PgUser: "postgres",
|
||||
PgPassword: "password",
|
||||
PgType: "host",
|
||||
}
|
||||
|
||||
apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger())
|
||||
|
||||
return NewFullBackuper(cfg, apiClient, logger.GetLogger())
|
||||
}
|
||||
|
||||
func mockCmdBuilder(t *testing.T, stdoutContent, stderrContent string) CmdBuilder {
|
||||
t.Helper()
|
||||
|
||||
return func(ctx context.Context) *exec.Cmd {
|
||||
cmd := exec.CommandContext(ctx, os.Args[0],
|
||||
"-test.run=TestHelperProcess",
|
||||
"--",
|
||||
stdoutContent,
|
||||
stderrContent,
|
||||
)
|
||||
|
||||
cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS=1")
|
||||
|
||||
return cmd
|
||||
}
|
||||
}
|
||||
|
||||
// TestHelperProcess is not a real test. When GO_TEST_HELPER_PROCESS is "1" it
// acts as a stand-in command: it prints the first argument after "--" to stdout
// and the second to stderr, then exits with status 0. Otherwise it returns
// immediately.
func TestHelperProcess(t *testing.T) {
	if os.Getenv("GO_TEST_HELPER_PROCESS") != "1" {
		return
	}

	// Keep only what follows the first "--" separator, if there is one.
	payload := os.Args
	for idx := range payload {
		if payload[idx] == "--" {
			payload = payload[idx+1:]
			break
		}
	}

	if len(payload) > 0 {
		_, _ = fmt.Fprint(os.Stdout, payload[0])
	}

	if len(payload) > 1 {
		_, _ = fmt.Fprint(os.Stderr, payload[1])
	}

	os.Exit(0)
}
|
||||
|
||||
// validStderr returns a sample of pg_basebackup verbose output that contains
// both the redo point line and the end point line the parser looks for.
func validStderr() string {
	return "pg_basebackup: initiating base backup, waiting for checkpoint to complete\n" +
		"pg_basebackup: checkpoint completed\n" +
		"pg_basebackup: write-ahead log start point: 0/2000028, on timeline 1\n" +
		"pg_basebackup: checkpoint redo point at 0/2000028\n" +
		"pg_basebackup: write-ahead log end point: 0/2000100\n" +
		"pg_basebackup: base backup completed"
}
|
||||
|
||||
// writeJSON sets the JSON content type and encodes v as the response body.
// If encoding fails it attempts to set a 500 status.
func writeJSON(w http.ResponseWriter, v any) {
	w.Header().Set("Content-Type", "application/json")

	encodeErr := json.NewEncoder(w).Encode(v)
	if encodeErr == nil {
		return
	}

	w.WriteHeader(http.StatusInternalServerError)
}
|
||||
|
||||
// waitForCondition polls condition every 50 milliseconds and returns as soon as
// it reports true. If timeout elapses first, the test is failed with t.Fatalf.
func waitForCondition(t *testing.T, condition func() bool, timeout time.Duration) {
	t.Helper()

	stopAt := time.Now().Add(timeout)

	for {
		if !time.Now().Before(stopAt) {
			break
		}

		if condition() {
			return
		}

		time.Sleep(50 * time.Millisecond)
	}

	t.Fatalf("condition not met within %v", timeout)
}
|
||||
|
||||
func setRetryDelay(d time.Duration) {
|
||||
retryDelayOverride = &d
|
||||
}
|
||||
|
||||
func init() {
|
||||
retryDelayOverride = nil
|
||||
}
|
||||
75
agent/internal/features/full_backup/stderr_parser.go
Normal file
75
agent/internal/features/full_backup/stderr_parser.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package full_backup
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const defaultWalSegmentSize uint32 = 16 * 1024 * 1024 // 16 MB
|
||||
|
||||
var (
|
||||
startLSNRegex = regexp.MustCompile(`checkpoint redo point at ([0-9A-Fa-f]+/[0-9A-Fa-f]+)`)
|
||||
stopLSNRegex = regexp.MustCompile(`write-ahead log end point: ([0-9A-Fa-f]+/[0-9A-Fa-f]+)`)
|
||||
)
|
||||
|
||||
func ParseBasebackupStderr(stderr string) (startSegment, stopSegment string, err error) {
|
||||
startMatch := startLSNRegex.FindStringSubmatch(stderr)
|
||||
if len(startMatch) < 2 {
|
||||
return "", "", fmt.Errorf("failed to parse start WAL location from pg_basebackup stderr")
|
||||
}
|
||||
|
||||
stopMatch := stopLSNRegex.FindStringSubmatch(stderr)
|
||||
if len(stopMatch) < 2 {
|
||||
return "", "", fmt.Errorf("failed to parse stop WAL location from pg_basebackup stderr")
|
||||
}
|
||||
|
||||
startSegment, err = LSNToSegmentName(startMatch[1], 1, defaultWalSegmentSize)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to convert start LSN to segment name: %w", err)
|
||||
}
|
||||
|
||||
stopSegment, err = LSNToSegmentName(stopMatch[1], 1, defaultWalSegmentSize)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to convert stop LSN to segment name: %w", err)
|
||||
}
|
||||
|
||||
return startSegment, stopSegment, nil
|
||||
}
|
||||
|
||||
// LSNToSegmentName converts an LSN written as "X/Y" (two hexadecimal 32-bit
// halves) into the 24-character WAL segment file name
// TTTTTTTTXXXXXXXXSSSSSSSS built from timelineID, the high half of the LSN and
// the index of the segment inside that high half.
//
// walSegmentSize is the segment size in bytes and must be greater than zero.
// An error is returned for a malformed LSN, a zero segment size, or a segment
// index that does not fit into one log ID.
func LSNToSegmentName(lsn string, timelineID, walSegmentSize uint32) (string, error) {
	high, low, err := parseLSN(lsn)
	if err != nil {
		return "", err
	}

	// Guard the divisions below; a zero size would panic at runtime.
	if walSegmentSize == 0 {
		return "", fmt.Errorf("WAL segment size must be greater than zero")
	}

	// Computed in 64 bits: 2^32 divided by a very small size does not fit in uint32.
	segmentsPerXLogID := uint64(0x100000000) / uint64(walSegmentSize)
	segmentOffset := uint64(low / walSegmentSize)

	if segmentOffset >= segmentsPerXLogID {
		return "", fmt.Errorf("segment offset %d exceeds segments per XLogId %d", segmentOffset, segmentsPerXLogID)
	}

	return fmt.Sprintf("%08X%08X%08X", timelineID, high, segmentOffset), nil
}

// parseLSN splits an "X/Y" LSN at the first slash and parses both halves as
// hexadecimal 32-bit values.
func parseLSN(lsn string) (high, low uint32, err error) {
	highPart, lowPart, found := strings.Cut(lsn, "/")
	if !found {
		return 0, 0, fmt.Errorf("invalid LSN format: %q (expected X/Y)", lsn)
	}

	highVal, err := strconv.ParseUint(highPart, 16, 32)
	if err != nil {
		return 0, 0, fmt.Errorf("invalid LSN high part %q: %w", highPart, err)
	}

	lowVal, err := strconv.ParseUint(lowPart, 16, 32)
	if err != nil {
		return 0, 0, fmt.Errorf("invalid LSN low part %q: %w", lowPart, err)
	}

	return uint32(highVal), uint32(lowVal), nil
}
|
||||
162
agent/internal/features/full_backup/stderr_parser_test.go
Normal file
162
agent/internal/features/full_backup/stderr_parser_test.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package full_backup
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_ParseBasebackupStderr_WithPG17Output_ExtractsCorrectSegments(t *testing.T) {
|
||||
stderr := `pg_basebackup: initiating base backup, waiting for checkpoint to complete
|
||||
pg_basebackup: checkpoint completed
|
||||
pg_basebackup: write-ahead log start point: 0/2000028, on timeline 1
|
||||
pg_basebackup: starting background WAL receiver
|
||||
pg_basebackup: checkpoint redo point at 0/2000028
|
||||
pg_basebackup: write-ahead log end point: 0/2000100
|
||||
pg_basebackup: waiting for background process to finish streaming ...
|
||||
pg_basebackup: syncing data to disk ...
|
||||
pg_basebackup: renaming backup_manifest.tmp to backup_manifest
|
||||
pg_basebackup: base backup completed`
|
||||
|
||||
startSeg, stopSeg, err := ParseBasebackupStderr(stderr)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "000000010000000000000002", startSeg)
|
||||
assert.Equal(t, "000000010000000000000002", stopSeg)
|
||||
}
|
||||
|
||||
func Test_ParseBasebackupStderr_WithPG15Output_ExtractsCorrectSegments(t *testing.T) {
|
||||
stderr := `pg_basebackup: initiating base backup, waiting for checkpoint to complete
|
||||
pg_basebackup: checkpoint completed
|
||||
pg_basebackup: write-ahead log start point: 1/AB000028, on timeline 1
|
||||
pg_basebackup: checkpoint redo point at 1/AB000028
|
||||
pg_basebackup: write-ahead log end point: 1/AC000000
|
||||
pg_basebackup: base backup completed`
|
||||
|
||||
startSeg, stopSeg, err := ParseBasebackupStderr(stderr)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "0000000100000001000000AB", startSeg)
|
||||
assert.Equal(t, "0000000100000001000000AC", stopSeg)
|
||||
}
|
||||
|
||||
func Test_ParseBasebackupStderr_WithHighLogID_ExtractsCorrectSegments(t *testing.T) {
|
||||
stderr := `pg_basebackup: checkpoint redo point at A/FF000028
|
||||
pg_basebackup: write-ahead log end point: B/1000000`
|
||||
|
||||
startSeg, stopSeg, err := ParseBasebackupStderr(stderr)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "000000010000000A000000FF", startSeg)
|
||||
assert.Equal(t, "000000010000000B00000001", stopSeg)
|
||||
}
|
||||
|
||||
func Test_ParseBasebackupStderr_WhenStartLSNMissing_ReturnsError(t *testing.T) {
|
||||
stderr := `pg_basebackup: write-ahead log end point: 0/2000100
|
||||
pg_basebackup: base backup completed`
|
||||
|
||||
_, _, err := ParseBasebackupStderr(stderr)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to parse start WAL location")
|
||||
}
|
||||
|
||||
func Test_ParseBasebackupStderr_WhenStopLSNMissing_ReturnsError(t *testing.T) {
|
||||
stderr := `pg_basebackup: checkpoint redo point at 0/2000028
|
||||
pg_basebackup: base backup completed`
|
||||
|
||||
_, _, err := ParseBasebackupStderr(stderr)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to parse stop WAL location")
|
||||
}
|
||||
|
||||
func Test_ParseBasebackupStderr_WhenEmptyStderr_ReturnsError(t *testing.T) {
|
||||
_, _, err := ParseBasebackupStderr("")
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to parse start WAL location")
|
||||
}
|
||||
|
||||
func Test_LSNToSegmentName_WithBoundaryValues_ConvertsCorrectly(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
lsn string
|
||||
timeline uint32
|
||||
segSize uint32
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "first segment",
|
||||
lsn: "0/1000000",
|
||||
timeline: 1,
|
||||
segSize: 16 * 1024 * 1024,
|
||||
expected: "000000010000000000000001",
|
||||
},
|
||||
{
|
||||
name: "segment at boundary FF",
|
||||
lsn: "0/FF000000",
|
||||
timeline: 1,
|
||||
segSize: 16 * 1024 * 1024,
|
||||
expected: "0000000100000000000000FF",
|
||||
},
|
||||
{
|
||||
name: "segment in second log file",
|
||||
lsn: "1/0",
|
||||
timeline: 1,
|
||||
segSize: 16 * 1024 * 1024,
|
||||
expected: "000000010000000100000000",
|
||||
},
|
||||
{
|
||||
name: "segment with offset within 16MB",
|
||||
lsn: "0/200ABCD",
|
||||
timeline: 1,
|
||||
segSize: 16 * 1024 * 1024,
|
||||
expected: "000000010000000000000002",
|
||||
},
|
||||
{
|
||||
name: "zero LSN",
|
||||
lsn: "0/0",
|
||||
timeline: 1,
|
||||
segSize: 16 * 1024 * 1024,
|
||||
expected: "000000010000000000000000",
|
||||
},
|
||||
{
|
||||
name: "high timeline ID",
|
||||
lsn: "0/1000000",
|
||||
timeline: 2,
|
||||
segSize: 16 * 1024 * 1024,
|
||||
expected: "000000020000000000000001",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := LSNToSegmentName(tt.lsn, tt.timeline, tt.segSize)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_LSNToSegmentName_WithInvalidLSN_ReturnsError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
lsn string
|
||||
}{
|
||||
{name: "no slash", lsn: "012345"},
|
||||
{name: "empty string", lsn: ""},
|
||||
{name: "invalid hex high", lsn: "GG/0"},
|
||||
{name: "invalid hex low", lsn: "0/ZZ"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := LSNToSegmentName(tt.lsn, 1, 16*1024*1024)
|
||||
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
121
agent/internal/features/start/daemon.go
Normal file
121
agent/internal/features/start/daemon.go
Normal file
@@ -0,0 +1,121 @@
|
||||
//go:build !windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
	// logFileName is the file the spawned daemon's stderr is appended to.
	logFileName = "databasus.log"
	// stopTimeout bounds how long Stop waits for the agent to exit after SIGTERM.
	stopTimeout = 30 * time.Second
	// stopPollInterval is the pause between liveness checks while Stop waits.
	stopPollInterval = 500 * time.Millisecond
	// daemonStartupDelay is the grace period spawnDaemon gives the child
	// before reporting it as started.
	daemonStartupDelay = 500 * time.Millisecond
)
|
||||
|
||||
func Stop(log *slog.Logger) error {
|
||||
pid, err := ReadLockFilePID()
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return errors.New("agent is not running (no lock file found)")
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to read lock file: %w", err)
|
||||
}
|
||||
|
||||
if !isProcessAlive(pid) {
|
||||
_ = os.Remove(lockFileName)
|
||||
return fmt.Errorf("agent is not running (stale lock file removed, PID %d)", pid)
|
||||
}
|
||||
|
||||
log.Info("Sending SIGTERM to agent", "pid", pid)
|
||||
|
||||
if err := syscall.Kill(pid, syscall.SIGTERM); err != nil {
|
||||
return fmt.Errorf("failed to send SIGTERM to PID %d: %w", pid, err)
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(stopTimeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if !isProcessAlive(pid) {
|
||||
log.Info("Agent stopped", "pid", pid)
|
||||
return nil
|
||||
}
|
||||
|
||||
time.Sleep(stopPollInterval)
|
||||
}
|
||||
|
||||
return fmt.Errorf("agent (PID %d) did not stop within %s — process may be stuck", pid, stopTimeout)
|
||||
}
|
||||
|
||||
func Status(log *slog.Logger) error {
|
||||
pid, err := ReadLockFilePID()
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
fmt.Println("Agent is not running")
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to read lock file: %w", err)
|
||||
}
|
||||
|
||||
if isProcessAlive(pid) {
|
||||
fmt.Printf("Agent is running (PID %d)\n", pid)
|
||||
} else {
|
||||
fmt.Println("Agent is not running (stale lock file)")
|
||||
_ = os.Remove(lockFileName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func spawnDaemon(log *slog.Logger) (int, error) {
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to resolve executable path: %w", err)
|
||||
}
|
||||
|
||||
args := []string{"_run"}
|
||||
|
||||
logFile, err := os.OpenFile(logFileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to open log file %s: %w", logFileName, err)
|
||||
}
|
||||
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
_ = logFile.Close()
|
||||
return 0, fmt.Errorf("failed to get working directory: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(context.Background(), execPath, args...)
|
||||
cmd.Dir = cwd
|
||||
cmd.Stderr = logFile
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
_ = logFile.Close()
|
||||
return 0, fmt.Errorf("failed to start daemon process: %w", err)
|
||||
}
|
||||
|
||||
pid := cmd.Process.Pid
|
||||
|
||||
// Detach — we don't wait for the child
|
||||
_ = logFile.Close()
|
||||
|
||||
time.Sleep(daemonStartupDelay)
|
||||
|
||||
if !isProcessAlive(pid) {
|
||||
return 0, fmt.Errorf("daemon process (PID %d) exited immediately — check %s for details", pid, logFileName)
|
||||
}
|
||||
|
||||
log.Info("Daemon spawned", "pid", pid, "log", logFileName)
|
||||
|
||||
return pid, nil
|
||||
}
|
||||
20
agent/internal/features/start/daemon_windows.go
Normal file
20
agent/internal/features/start/daemon_windows.go
Normal file
@@ -0,0 +1,20 @@
|
||||
//go:build windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// Stop always fails on Windows, where the agent runs in the foreground.
func Stop(log *slog.Logger) error {
	const reason = "stop is not supported on Windows — use Ctrl+C in the terminal where the agent is running"

	return errors.New(reason)
}
|
||||
|
||||
// Status always fails on Windows, where the agent runs in the foreground.
func Status(log *slog.Logger) error {
	const reason = "status is not supported on Windows — check the terminal where the agent is running"

	return errors.New(reason)
}
|
||||
|
||||
// spawnDaemon always fails on Windows; it returns PID 0 and an error.
func spawnDaemon(_ *slog.Logger) (int, error) {
	const reason = "daemon mode is not supported on Windows"

	return 0, errors.New(reason)
}
|
||||
117
agent/internal/features/start/lock.go
Normal file
117
agent/internal/features/start/lock.go
Normal file
@@ -0,0 +1,117 @@
|
||||
//go:build !windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// lockFileName is the lock/PID file path, relative to the working directory.
const lockFileName = "databasus.lock"
|
||||
|
||||
func AcquireLock(log *slog.Logger) (*os.File, error) {
|
||||
f, err := os.OpenFile(lockFileName, os.O_CREATE|os.O_RDWR, 0o644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open lock file: %w", err)
|
||||
}
|
||||
|
||||
err = syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
|
||||
if err == nil {
|
||||
if err := writePID(f); err != nil {
|
||||
_ = f.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Info("Process lock acquired", "pid", os.Getpid(), "lockFile", lockFileName)
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
if !errors.Is(err, syscall.EWOULDBLOCK) {
|
||||
_ = f.Close()
|
||||
return nil, fmt.Errorf("failed to acquire lock: %w", err)
|
||||
}
|
||||
|
||||
pid, pidErr := readLockPID(f)
|
||||
_ = f.Close()
|
||||
|
||||
if pidErr != nil {
|
||||
return nil, fmt.Errorf("another instance is already running")
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("another instance is already running (PID %d)", pid)
|
||||
}
|
||||
|
||||
func ReleaseLock(f *os.File) {
|
||||
_ = syscall.Flock(int(f.Fd()), syscall.LOCK_UN)
|
||||
_ = f.Close()
|
||||
_ = os.Remove(lockFileName)
|
||||
}
|
||||
|
||||
func ReadLockFilePID() (int, error) {
|
||||
f, err := os.Open(lockFileName)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
return readLockPID(f)
|
||||
}
|
||||
|
||||
// writePID replaces the content of f with the current process ID followed by
// a newline and syncs the file.
func writePID(f *os.File) error {
	if truncErr := f.Truncate(0); truncErr != nil {
		return fmt.Errorf("failed to truncate lock file: %w", truncErr)
	}

	if _, seekErr := f.Seek(0, io.SeekStart); seekErr != nil {
		return fmt.Errorf("failed to seek lock file: %w", seekErr)
	}

	pidLine := strconv.Itoa(os.Getpid()) + "\n"
	if _, writeErr := io.WriteString(f, pidLine); writeErr != nil {
		return fmt.Errorf("failed to write PID to lock file: %w", writeErr)
	}

	return f.Sync()
}
|
||||
|
||||
// readLockPID rewinds f, reads its whole content and parses it as a decimal
// process ID, ignoring surrounding whitespace.
//
// It returns an error when the file cannot be read, is empty, does not hold a
// number, or holds a non-positive number. Non-positive values are rejected
// because kill(2) treats them as process-group or broadcast targets, not as a
// single process.
func readLockPID(f *os.File) (int, error) {
	if _, err := f.Seek(0, io.SeekStart); err != nil {
		return 0, err
	}

	data, err := io.ReadAll(f)
	if err != nil {
		return 0, err
	}

	s := strings.TrimSpace(string(data))
	if s == "" {
		return 0, errors.New("lock file is empty")
	}

	pid, err := strconv.Atoi(s)
	if err != nil {
		return 0, fmt.Errorf("invalid PID in lock file: %w", err)
	}

	if pid <= 0 {
		return 0, fmt.Errorf("invalid PID in lock file: %d", pid)
	}

	return pid, nil
}
|
||||
|
||||
// isProcessAlive reports whether a process with the given PID exists, probing
// it with signal 0. EPERM counts as alive: the process exists but belongs to
// someone else.
//
// Non-positive PIDs are never alive. kill(2) interprets them as "own process
// group" or "every process", so the probe would succeed without saying
// anything about a single process.
func isProcessAlive(pid int) bool {
	if pid <= 0 {
		return false
	}

	err := syscall.Kill(pid, 0)

	return err == nil || errors.Is(err, syscall.EPERM)
}
|
||||
148
agent/internal/features/start/lock_test.go
Normal file
148
agent/internal/features/start/lock_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
//go:build !windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
func Test_AcquireLock_LockFileCreatedWithPID(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
data, err := os.ReadFile(lockFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, os.Getpid(), pid)
|
||||
}
|
||||
|
||||
func Test_AcquireLock_SecondAcquireFails_WhenFirstHeld(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
first, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(first)
|
||||
|
||||
second, err := AcquireLock(log)
|
||||
assert.Nil(t, second)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "another instance is already running")
|
||||
assert.Contains(t, err.Error(), fmt.Sprintf("PID %d", os.Getpid()))
|
||||
}
|
||||
|
||||
func Test_AcquireLock_StaleLockReacquired_WhenProcessDead(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
err := os.WriteFile(lockFileName, []byte("999999999\n"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
data, err := os.ReadFile(lockFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, os.Getpid(), pid)
|
||||
}
|
||||
|
||||
func Test_ReleaseLock_LockFileRemoved(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
|
||||
ReleaseLock(lockFile)
|
||||
|
||||
_, err = os.Stat(lockFileName)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
}
|
||||
|
||||
func Test_AcquireLock_ReacquiredAfterRelease(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
first, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
ReleaseLock(first)
|
||||
|
||||
second, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(second)
|
||||
|
||||
data, err := os.ReadFile(lockFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, os.Getpid(), pid)
|
||||
}
|
||||
|
||||
func Test_isProcessAlive_ReturnsTrueForSelf(t *testing.T) {
|
||||
assert.True(t, isProcessAlive(os.Getpid()))
|
||||
}
|
||||
|
||||
func Test_isProcessAlive_ReturnsFalseForNonExistentPID(t *testing.T) {
|
||||
assert.False(t, isProcessAlive(999999999))
|
||||
}
|
||||
|
||||
func Test_readLockPID_ParsesValidPID(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
f, err := os.CreateTemp("", "lock-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(f.Name())
|
||||
|
||||
_, err = f.WriteString("12345\n")
|
||||
require.NoError(t, err)
|
||||
|
||||
pid, err := readLockPID(f)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 12345, pid)
|
||||
}
|
||||
|
||||
func Test_readLockPID_ReturnsErrorForEmptyFile(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
f, err := os.CreateTemp("", "lock-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(f.Name())
|
||||
|
||||
_, err = readLockPID(f)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "lock file is empty")
|
||||
}
|
||||
|
||||
func setupTempDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
origDir, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
|
||||
dir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(dir))
|
||||
|
||||
t.Cleanup(func() { _ = os.Chdir(origDir) })
|
||||
|
||||
return dir
|
||||
}
|
||||
90
agent/internal/features/start/lock_watcher.go
Normal file
90
agent/internal/features/start/lock_watcher.go
Normal file
@@ -0,0 +1,90 @@
|
||||
//go:build !windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// lockWatchInterval is the period between lock file checks in LockWatcher.Run.
const lockWatchInterval = 5 * time.Second
|
||||
|
||||
// LockWatcher periodically verifies that the lock file on disk is still the
// file that was locked, and cancels the agent's context when it is not.
type LockWatcher struct {
	// originalInode is the inode of the lock file handle captured at construction.
	originalInode uint64
	// cancel is invoked when the lock file is missing or has a different inode.
	cancel context.CancelFunc
	log    *slog.Logger
}
|
||||
|
||||
func NewLockWatcher(lockFile *os.File, cancel context.CancelFunc, log *slog.Logger) (*LockWatcher, error) {
|
||||
inode, err := getFileInode(lockFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &LockWatcher{
|
||||
originalInode: inode,
|
||||
cancel: cancel,
|
||||
log: log,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *LockWatcher) Run(ctx context.Context) {
|
||||
ticker := time.NewTicker(lockWatchInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
w.check()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *LockWatcher) check() {
|
||||
info, err := os.Stat(lockFileName)
|
||||
if err != nil {
|
||||
w.log.Error("Lock file disappeared, shutting down", "file", lockFileName, "error", err)
|
||||
w.cancel()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
currentInode, err := getStatInode(info)
|
||||
if err != nil {
|
||||
w.log.Error("Failed to read lock file inode, shutting down", "error", err)
|
||||
w.cancel()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if currentInode != w.originalInode {
|
||||
w.log.Error("Lock file was replaced (inode changed), shutting down",
|
||||
"originalInode", w.originalInode,
|
||||
"currentInode", currentInode,
|
||||
)
|
||||
w.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
func getFileInode(f *os.File) (uint64, error) {
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return getStatInode(info)
|
||||
}
|
||||
|
||||
// getStatInode extracts the inode number from a FileInfo. It returns
// os.ErrInvalid when the underlying data is not a *syscall.Stat_t.
func getStatInode(info os.FileInfo) (uint64, error) {
	if raw, isStat := info.Sys().(*syscall.Stat_t); isStat {
		return raw.Ino, nil
	}

	return 0, os.ErrInvalid
}
|
||||
110
agent/internal/features/start/lock_watcher_test.go
Normal file
110
agent/internal/features/start/lock_watcher_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
//go:build !windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
func Test_NewLockWatcher_CapturesInode(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, watcher.originalInode)
|
||||
}
|
||||
|
||||
func Test_LockWatcher_FileUnchanged_ContextNotCancelled(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
watcher.check()
|
||||
watcher.check()
|
||||
watcher.check()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("context should not be cancelled when lock file is unchanged")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func Test_LockWatcher_FileDeleted_CancelsContext(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.Remove(lockFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
watcher.check()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
default:
|
||||
t.Fatal("context should be cancelled when lock file is deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_LockWatcher_FileReplacedWithDifferentInode_CancelsContext(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.Remove(lockFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(lockFileName, []byte("99999\n"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
watcher.check()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
default:
|
||||
t.Fatal("context should be cancelled when lock file inode changes")
|
||||
}
|
||||
}
|
||||
17
agent/internal/features/start/lock_watcher_windows.go
Normal file
17
agent/internal/features/start/lock_watcher_windows.go
Normal file
@@ -0,0 +1,17 @@
|
||||
//go:build windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
)
|
||||
|
||||
// LockWatcher is a no-op placeholder on Windows.
type LockWatcher struct{}

// NewLockWatcher ignores its arguments and returns an empty watcher.
func NewLockWatcher(_ *os.File, _ context.CancelFunc, _ *slog.Logger) (*LockWatcher, error) {
	stub := new(LockWatcher)

	return stub, nil
}

// Run returns immediately; there is nothing to watch on Windows.
func (w *LockWatcher) Run(_ context.Context) {
}
|
||||
18
agent/internal/features/start/lock_windows.go
Normal file
18
agent/internal/features/start/lock_windows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package start
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
)
|
||||
|
||||
// AcquireLock does not lock anything on Windows: it logs a warning and returns
// a nil file with a nil error.
func AcquireLock(log *slog.Logger) (*os.File, error) {
	var noLockFile *os.File

	log.Warn("Process locking is not supported on Windows, skipping")

	return noLockFile, nil
}
|
||||
|
||||
// ReleaseLock closes f when it is non-nil; a nil handle is a no-op.
func ReleaseLock(f *os.File) {
	if f == nil {
		return
	}

	_ = f.Close()
}
|
||||
@@ -1,21 +1,101 @@
|
||||
package start
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
full_backup "databasus-agent/internal/features/full_backup"
|
||||
"databasus-agent/internal/features/upgrade"
|
||||
"databasus-agent/internal/features/wal"
|
||||
)
|
||||
|
||||
func Run(cfg *config.Config, log *slog.Logger) error {
|
||||
const (
	// pgBasebackupVerifyTimeout bounds the `pg_basebackup --version` probe.
	pgBasebackupVerifyTimeout = 10 * time.Second
	// dbVerifyTimeout bounds the connect, ping and version query in verifyDatabase.
	dbVerifyTimeout = 10 * time.Second
	// minPgMajorVersion is the lowest PostgreSQL major version verifyDatabase accepts.
	minPgMajorVersion = 15
)
|
||||
|
||||
func Start(cfg *config.Config, agentVersion string, isDev bool, log *slog.Logger) error {
|
||||
if err := validateConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("start: stub — not yet implemented",
|
||||
"dbId", cfg.DbID,
|
||||
"hasToken", cfg.Token != "",
|
||||
)
|
||||
if err := verifyPgBasebackup(cfg, log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := verifyDatabase(cfg, log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
return RunDaemon(cfg, agentVersion, isDev, log)
|
||||
}
|
||||
|
||||
pid, err := spawnDaemon(log)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Agent started in background (PID %d)\n", pid)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunDaemon(cfg *config.Config, agentVersion string, isDev bool, log *slog.Logger) error {
|
||||
lockFile, err := AcquireLock(log)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize lock watcher: %w", err)
|
||||
}
|
||||
go watcher.Run(ctx)
|
||||
|
||||
apiClient := api.NewClient(cfg.DatabasusHost, cfg.Token, log)
|
||||
|
||||
var backgroundUpgrader *upgrade.BackgroundUpgrader
|
||||
if agentVersion != "dev" && runtime.GOOS != "windows" {
|
||||
backgroundUpgrader = upgrade.NewBackgroundUpgrader(apiClient, agentVersion, isDev, cancel, log)
|
||||
go backgroundUpgrader.Run(ctx)
|
||||
}
|
||||
|
||||
fullBackuper := full_backup.NewFullBackuper(cfg, apiClient, log)
|
||||
go fullBackuper.Run(ctx)
|
||||
|
||||
streamer := wal.NewStreamer(cfg, apiClient, log)
|
||||
streamer.Run(ctx)
|
||||
|
||||
if backgroundUpgrader != nil {
|
||||
backgroundUpgrader.WaitForCompletion(30 * time.Second)
|
||||
|
||||
if backgroundUpgrader.IsUpgraded() {
|
||||
return upgrade.ErrUpgradeRestart
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("Agent stopped")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -33,5 +113,159 @@ func validateConfig(cfg *config.Config) error {
|
||||
return errors.New("argument token is required")
|
||||
}
|
||||
|
||||
if cfg.PgHost == "" {
|
||||
return errors.New("argument pg-host is required")
|
||||
}
|
||||
|
||||
if cfg.PgPort <= 0 {
|
||||
return errors.New("argument pg-port must be a positive number")
|
||||
}
|
||||
|
||||
if cfg.PgUser == "" {
|
||||
return errors.New("argument pg-user is required")
|
||||
}
|
||||
|
||||
if cfg.PgType != "host" && cfg.PgType != "docker" {
|
||||
return fmt.Errorf("argument pg-type must be 'host' or 'docker', got '%s'", cfg.PgType)
|
||||
}
|
||||
|
||||
if cfg.PgWalDir == "" {
|
||||
return errors.New("argument pg-wal-dir is required")
|
||||
}
|
||||
|
||||
if cfg.PgType == "docker" && cfg.PgDockerContainerName == "" {
|
||||
return errors.New("argument pg-docker-container-name is required when pg-type is 'docker'")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyPgBasebackup(cfg *config.Config, log *slog.Logger) error {
|
||||
switch cfg.PgType {
|
||||
case "host":
|
||||
return verifyPgBasebackupHost(cfg, log)
|
||||
case "docker":
|
||||
return verifyPgBasebackupDocker(cfg, log)
|
||||
default:
|
||||
return fmt.Errorf("unexpected pg-type: %s", cfg.PgType)
|
||||
}
|
||||
}
|
||||
|
||||
func verifyPgBasebackupHost(cfg *config.Config, log *slog.Logger) error {
|
||||
binary := "pg_basebackup"
|
||||
if cfg.PgHostBinDir != "" {
|
||||
binary = filepath.Join(cfg.PgHostBinDir, "pg_basebackup")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgBasebackupVerifyTimeout)
|
||||
defer cancel()
|
||||
|
||||
output, err := exec.CommandContext(ctx, binary, "--version").CombinedOutput()
|
||||
if err != nil {
|
||||
if cfg.PgHostBinDir != "" {
|
||||
return fmt.Errorf(
|
||||
"pg_basebackup not found at '%s': %w. Verify pg-host-bin-dir is correct",
|
||||
binary, err,
|
||||
)
|
||||
}
|
||||
|
||||
return fmt.Errorf(
|
||||
"pg_basebackup not found in PATH: %w. Install PostgreSQL client tools or set pg-host-bin-dir",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
log.Info("pg_basebackup verified", "version", strings.TrimSpace(string(output)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyPgBasebackupDocker(cfg *config.Config, log *slog.Logger) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgBasebackupVerifyTimeout)
|
||||
defer cancel()
|
||||
|
||||
output, err := exec.CommandContext(ctx,
|
||||
"docker", "exec", cfg.PgDockerContainerName,
|
||||
"pg_basebackup", "--version",
|
||||
).CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"pg_basebackup not available in container '%s': %w. "+
|
||||
"Check that the container is running and pg_basebackup is installed inside it",
|
||||
cfg.PgDockerContainerName, err,
|
||||
)
|
||||
}
|
||||
|
||||
log.Info("pg_basebackup verified (docker)",
|
||||
"container", cfg.PgDockerContainerName,
|
||||
"version", strings.TrimSpace(string(output)),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyDatabase(cfg *config.Config, log *slog.Logger) error {
|
||||
connStr := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=postgres sslmode=disable",
|
||||
cfg.PgHost, cfg.PgPort, cfg.PgUser, cfg.PgPassword,
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dbVerifyTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgx.Connect(ctx, connStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to connect to PostgreSQL at %s:%d as user '%s': %w",
|
||||
cfg.PgHost, cfg.PgPort, cfg.PgUser, err,
|
||||
)
|
||||
}
|
||||
defer func() { _ = conn.Close(ctx) }()
|
||||
|
||||
if err := conn.Ping(ctx); err != nil {
|
||||
return fmt.Errorf("PostgreSQL ping failed at %s:%d: %w",
|
||||
cfg.PgHost, cfg.PgPort, err,
|
||||
)
|
||||
}
|
||||
|
||||
var versionNumStr string
|
||||
if err := conn.QueryRow(ctx, "SHOW server_version_num").Scan(&versionNumStr); err != nil {
|
||||
return fmt.Errorf("failed to query PostgreSQL version: %w", err)
|
||||
}
|
||||
|
||||
majorVersion, err := parsePgVersionNum(versionNumStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse PostgreSQL version '%s': %w", versionNumStr, err)
|
||||
}
|
||||
|
||||
if majorVersion < minPgMajorVersion {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL %d is not supported, minimum required version is %d",
|
||||
majorVersion, minPgMajorVersion,
|
||||
)
|
||||
}
|
||||
|
||||
log.Info("PostgreSQL connection verified",
|
||||
"host", cfg.PgHost,
|
||||
"port", cfg.PgPort,
|
||||
"user", cfg.PgUser,
|
||||
"version", majorVersion,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parsePgVersionNum converts a server_version_num value (for example
// "150004") into the PostgreSQL major version (15). Surrounding whitespace is
// ignored. An error is returned when the input is not an integer or is not
// strictly positive.
func parsePgVersionNum(versionNumStr string) (int, error) {
	trimmed := strings.TrimSpace(versionNumStr)

	parsed, convErr := strconv.Atoi(trimmed)
	switch {
	case convErr != nil:
		return 0, fmt.Errorf("invalid version number: %w", convErr)
	case parsed <= 0:
		return 0, fmt.Errorf("invalid version number: %d", parsed)
	}

	// The major version is everything above the lowest four decimal digits.
	return parsed / 10000, nil
}
|
||||
|
||||
84
agent/internal/features/start/start_test.go
Normal file
84
agent/internal/features/start/start_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package start
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_ParsePgVersionNum_SupportedVersions_ReturnsMajorVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
versionNumStr string
|
||||
expectedMajor int
|
||||
}{
|
||||
{name: "PG 15.0", versionNumStr: "150000", expectedMajor: 15},
|
||||
{name: "PG 15.4", versionNumStr: "150004", expectedMajor: 15},
|
||||
{name: "PG 16.0", versionNumStr: "160000", expectedMajor: 16},
|
||||
{name: "PG 16.3", versionNumStr: "160003", expectedMajor: 16},
|
||||
{name: "PG 17.2", versionNumStr: "170002", expectedMajor: 17},
|
||||
{name: "PG 18.0", versionNumStr: "180000", expectedMajor: 18},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
major, err := parsePgVersionNum(tt.versionNumStr)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedMajor, major)
|
||||
assert.GreaterOrEqual(t, major, minPgMajorVersion)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ParsePgVersionNum_UnsupportedVersions_ReturnsMajorVersionBelow15(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
versionNumStr string
|
||||
expectedMajor int
|
||||
}{
|
||||
{name: "PG 12.5", versionNumStr: "120005", expectedMajor: 12},
|
||||
{name: "PG 13.0", versionNumStr: "130000", expectedMajor: 13},
|
||||
{name: "PG 14.12", versionNumStr: "140012", expectedMajor: 14},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
major, err := parsePgVersionNum(tt.versionNumStr)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedMajor, major)
|
||||
assert.Less(t, major, minPgMajorVersion)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ParsePgVersionNum_InvalidInput_ReturnsError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
versionNumStr string
|
||||
}{
|
||||
{name: "empty string", versionNumStr: ""},
|
||||
{name: "non-numeric", versionNumStr: "abc"},
|
||||
{name: "negative number", versionNumStr: "-1"},
|
||||
{name: "zero", versionNumStr: "0"},
|
||||
{name: "float", versionNumStr: "15.4"},
|
||||
{name: "whitespace only", versionNumStr: " "},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := parsePgVersionNum(tt.versionNumStr)
|
||||
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ParsePgVersionNum_WithWhitespace_ParsesCorrectly(t *testing.T) {
|
||||
major, err := parsePgVersionNum(" 150004 ")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 15, major)
|
||||
}
|
||||
88
agent/internal/features/upgrade/background_upgrader.go
Normal file
88
agent/internal/features/upgrade/background_upgrader.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package upgrade
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"databasus-agent/internal/features/api"
|
||||
)
|
||||
|
||||
// backgroundCheckInterval is the period of the ticker that drives the update
// checks in BackgroundUpgrader.Run.
const backgroundCheckInterval = 5 * time.Second
|
||||
|
||||
// BackgroundUpgrader periodically checks for a newer agent binary and, once an
// upgrade has been applied, records that fact and invokes the cancel function
// it was constructed with.
type BackgroundUpgrader struct {
|
||||
// apiClient is handed to CheckAndUpdate on every check.
apiClient *api.Client
|
||||
// currentVersion is the version string passed to CheckAndUpdate.
currentVersion string
|
||||
// isDev is forwarded to CheckAndUpdate.
isDev bool
|
||||
// cancel is called by checkAndUpgrade right after a successful upgrade.
cancel context.CancelFunc
|
||||
// isUpgraded is set to true after a successful upgrade; read via IsUpgraded.
isUpgraded atomic.Bool
|
||||
// log receives the upgrader's own messages and is also passed to CheckAndUpdate.
log *slog.Logger
|
||||
// done is closed when Run returns; WaitForCompletion waits on it.
done chan struct{}
|
||||
}
|
||||
|
||||
func NewBackgroundUpgrader(
|
||||
apiClient *api.Client,
|
||||
currentVersion string,
|
||||
isDev bool,
|
||||
cancel context.CancelFunc,
|
||||
log *slog.Logger,
|
||||
) *BackgroundUpgrader {
|
||||
return &BackgroundUpgrader{
|
||||
apiClient,
|
||||
currentVersion,
|
||||
isDev,
|
||||
cancel,
|
||||
atomic.Bool{},
|
||||
log,
|
||||
make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (u *BackgroundUpgrader) Run(ctx context.Context) {
|
||||
defer close(u.done)
|
||||
|
||||
ticker := time.NewTicker(backgroundCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if u.checkAndUpgrade() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *BackgroundUpgrader) IsUpgraded() bool {
|
||||
return u.isUpgraded.Load()
|
||||
}
|
||||
|
||||
func (u *BackgroundUpgrader) WaitForCompletion(timeout time.Duration) {
|
||||
select {
|
||||
case <-u.done:
|
||||
case <-time.After(timeout):
|
||||
}
|
||||
}
|
||||
|
||||
func (u *BackgroundUpgrader) checkAndUpgrade() bool {
|
||||
isUpgraded, err := CheckAndUpdate(u.apiClient, u.currentVersion, u.isDev, u.log)
|
||||
if err != nil {
|
||||
u.log.Warn("Background update check failed", "error", err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
if !isUpgraded {
|
||||
return false
|
||||
}
|
||||
|
||||
u.log.Info("Background upgrade complete, restarting...")
|
||||
u.isUpgraded.Store(true)
|
||||
u.cancel()
|
||||
|
||||
return true
|
||||
}
|
||||
5
agent/internal/features/upgrade/errors.go
Normal file
5
agent/internal/features/upgrade/errors.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package upgrade
|
||||
|
||||
import "errors"
|
||||
|
||||
// ErrUpgradeRestart is a sentinel error stating that the agent was upgraded
// and a restart is required.
var ErrUpgradeRestart = errors.New("agent upgraded, restart required")
|
||||
@@ -2,49 +2,47 @@ package upgrade
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"databasus-agent/internal/features/api"
|
||||
)
|
||||
|
||||
type Logger interface {
|
||||
Info(msg string, args ...any)
|
||||
Warn(msg string, args ...any)
|
||||
Error(msg string, args ...any)
|
||||
}
|
||||
|
||||
type versionResponse struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
func CheckAndUpdate(databasusHost, currentVersion string, isDev bool, log Logger) error {
|
||||
// CheckAndUpdate checks if a new version is available and upgrades the binary on disk.
|
||||
// Returns (true, nil) if the binary was upgraded, (false, nil) if already up to date,
|
||||
// or (false, err) on failure. Callers are responsible for re-exec or restart signaling.
|
||||
func CheckAndUpdate(apiClient *api.Client, currentVersion string, isDev bool, log *slog.Logger) (bool, error) {
|
||||
if isDev {
|
||||
log.Info("Skipping update check (development mode)")
|
||||
return nil
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
serverVersion, err := fetchServerVersion(databasusHost, log)
|
||||
serverVersion, err := apiClient.FetchServerVersion(context.Background())
|
||||
if err != nil {
|
||||
return nil
|
||||
log.Warn("Could not reach server for update check", "error", err)
|
||||
|
||||
return false, fmt.Errorf(
|
||||
"unable to check version, please verify Databasus server is available: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
if serverVersion == currentVersion {
|
||||
log.Info("Agent version is up to date", "version", currentVersion)
|
||||
return nil
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
log.Info("Updating agent...", "current", currentVersion, "target", serverVersion)
|
||||
|
||||
selfPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to determine executable path: %w", err)
|
||||
return false, fmt.Errorf("failed to determine executable path: %w", err)
|
||||
}
|
||||
|
||||
tempPath := selfPath + ".update"
|
||||
@@ -53,93 +51,25 @@ func CheckAndUpdate(databasusHost, currentVersion string, isDev bool, log Logger
|
||||
_ = os.Remove(tempPath)
|
||||
}()
|
||||
|
||||
if err := downloadBinary(databasusHost, tempPath); err != nil {
|
||||
return fmt.Errorf("failed to download update: %w", err)
|
||||
if err := apiClient.DownloadAgentBinary(context.Background(), runtime.GOARCH, tempPath); err != nil {
|
||||
return false, fmt.Errorf("failed to download update: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tempPath, 0o755); err != nil {
|
||||
return fmt.Errorf("failed to set permissions on update: %w", err)
|
||||
return false, fmt.Errorf("failed to set permissions on update: %w", err)
|
||||
}
|
||||
|
||||
if err := verifyBinary(tempPath, serverVersion); err != nil {
|
||||
return fmt.Errorf("update verification failed: %w", err)
|
||||
return false, fmt.Errorf("update verification failed: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tempPath, selfPath); err != nil {
|
||||
return fmt.Errorf("failed to replace binary (try --skip-update if this persists): %w", err)
|
||||
return false, fmt.Errorf("failed to replace binary (try --skip-update if this persists): %w", err)
|
||||
}
|
||||
|
||||
log.Info("Update complete, re-executing...")
|
||||
log.Info("Agent binary updated", "version", serverVersion)
|
||||
|
||||
return syscall.Exec(selfPath, os.Args, os.Environ())
|
||||
}
|
||||
|
||||
func fetchServerVersion(host string, log Logger) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, host+"/api/v1/system/version", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Warn("Could not reach server for update check, continuing", "error", err)
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Warn(
|
||||
"Server returned non-OK status for version check, continuing",
|
||||
"status",
|
||||
resp.StatusCode,
|
||||
)
|
||||
return "", fmt.Errorf("status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var ver versionResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&ver); err != nil {
|
||||
log.Warn("Failed to parse server version response, continuing", "error", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
return ver.Version, nil
|
||||
}
|
||||
|
||||
func downloadBinary(host, destPath string) error {
|
||||
url := fmt.Sprintf("%s/api/v1/system/agent?arch=%s", host, runtime.GOARCH)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("server returned %d for agent download", resp.StatusCode)
|
||||
}
|
||||
|
||||
f, err := os.Create(destPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
_, err = io.Copy(f, resp.Body)
|
||||
|
||||
return err
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func verifyBinary(binaryPath, expectedVersion string) error {
|
||||
|
||||
182
agent/internal/features/wal/streamer.go
Normal file
182
agent/internal/features/wal/streamer.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
)
|
||||
|
||||
// Timing parameters of the WAL streamer.
const (
|
||||
// pollInterval is the ticker period between processQueue passes in Streamer.Run.
pollInterval = 2 * time.Second
|
||||
// uploadTimeout bounds a single segment upload in Streamer.uploadSegment.
uploadTimeout = 5 * time.Minute
|
||||
)
|
||||
|
||||
// segmentNameRegex matches names made of exactly 24 hexadecimal characters;
// listSegments uses it to decide which directory entries are WAL segments.
var segmentNameRegex = regexp.MustCompile(`^[0-9A-Fa-f]{24}$`)
|
||||
|
||||
// Streamer uploads WAL segment files found in the configured WAL directory
// through the API client.
type Streamer struct {
|
||||
// cfg supplies PgWalDir and IsDeleteWalAfterUpload.
cfg *config.Config
|
||||
// apiClient performs the segment upload (UploadWalSegment).
apiClient *api.Client
|
||||
// log receives the streamer's progress and error messages.
log *slog.Logger
|
||||
}
|
||||
|
||||
func NewStreamer(cfg *config.Config, apiClient *api.Client, log *slog.Logger) *Streamer {
|
||||
return &Streamer{
|
||||
cfg: cfg,
|
||||
apiClient: apiClient,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Streamer) Run(ctx context.Context) {
|
||||
s.log.Info("WAL streamer started", "pgWalDir", s.cfg.PgWalDir)
|
||||
|
||||
s.processQueue(ctx)
|
||||
|
||||
ticker := time.NewTicker(pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.log.Info("WAL streamer stopping")
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.processQueue(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Streamer) processQueue(ctx context.Context) {
|
||||
segments, err := s.listSegments()
|
||||
if err != nil {
|
||||
s.log.Error("Failed to list WAL segments", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, segmentName := range segments {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.uploadSegment(ctx, segmentName); err != nil {
|
||||
s.log.Error("Failed to upload WAL segment",
|
||||
"segment", segmentName,
|
||||
"error", err,
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Streamer) listSegments() ([]string, error) {
|
||||
entries, err := os.ReadDir(s.cfg.PgWalDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read wal dir: %w", err)
|
||||
}
|
||||
|
||||
var segments []string
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
name := entry.Name()
|
||||
|
||||
if strings.HasSuffix(name, ".tmp") {
|
||||
continue
|
||||
}
|
||||
|
||||
if !segmentNameRegex.MatchString(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
segments = append(segments, name)
|
||||
}
|
||||
|
||||
sort.Strings(segments)
|
||||
|
||||
return segments, nil
|
||||
}
|
||||
|
||||
func (s *Streamer) uploadSegment(ctx context.Context, segmentName string) error {
|
||||
filePath := filepath.Join(s.cfg.PgWalDir, segmentName)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
go s.compressAndStream(pw, filePath)
|
||||
|
||||
uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout)
|
||||
defer cancel()
|
||||
|
||||
result, err := s.apiClient.UploadWalSegment(uploadCtx, segmentName, pr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if result.IsGapDetected {
|
||||
s.log.Warn("WAL chain gap detected",
|
||||
"segment", segmentName,
|
||||
"expected", result.ExpectedSegmentName,
|
||||
"received", result.ReceivedSegmentName,
|
||||
)
|
||||
|
||||
return fmt.Errorf("gap detected for segment %s", segmentName)
|
||||
}
|
||||
|
||||
s.log.Debug("WAL segment uploaded", "segment", segmentName)
|
||||
|
||||
if *s.cfg.IsDeleteWalAfterUpload {
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
s.log.Warn("Failed to delete uploaded WAL segment",
|
||||
"segment", segmentName,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Streamer) compressAndStream(pw *io.PipeWriter, filePath string) {
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
_ = pw.CloseWithError(fmt.Errorf("open file: %w", err))
|
||||
return
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
encoder, err := zstd.NewWriter(pw,
|
||||
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(5)),
|
||||
zstd.WithEncoderCRC(true),
|
||||
)
|
||||
if err != nil {
|
||||
_ = pw.CloseWithError(fmt.Errorf("create zstd encoder: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := io.Copy(encoder, f); err != nil {
|
||||
_ = encoder.Close()
|
||||
_ = pw.CloseWithError(fmt.Errorf("compress: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := encoder.Close(); err != nil {
|
||||
_ = pw.CloseWithError(fmt.Errorf("close encoder: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
_ = pw.Close()
|
||||
}
|
||||
348
agent/internal/features/wal/streamer_test.go
Normal file
348
agent/internal/features/wal/streamer_test.go
Normal file
@@ -0,0 +1,348 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
func Test_UploadSegment_SingleSegment_ServerReceivesCorrectHeadersAndBody(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
segmentContent := []byte("test-wal-segment-data-for-upload")
|
||||
writeTestSegment(t, walDir, "000000010000000100000001", segmentContent)
|
||||
|
||||
var receivedHeaders http.Header
|
||||
var receivedBody []byte
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedHeaders = r.Header.Clone()
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
receivedBody = body
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
require.NotNil(t, receivedHeaders)
|
||||
assert.Equal(t, "test-token", receivedHeaders.Get("Authorization"))
|
||||
assert.Equal(t, "application/octet-stream", receivedHeaders.Get("Content-Type"))
|
||||
assert.Equal(t, "000000010000000100000001", receivedHeaders.Get("X-Wal-Segment-Name"))
|
||||
|
||||
decompressed := decompressZstd(t, receivedBody)
|
||||
assert.Equal(t, segmentContent, decompressed)
|
||||
}
|
||||
|
||||
func Test_UploadSegments_MultipleSegmentsOutOfOrder_UploadedInAscendingOrder(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
writeTestSegment(t, walDir, "000000010000000100000003", []byte("third"))
|
||||
writeTestSegment(t, walDir, "000000010000000100000001", []byte("first"))
|
||||
writeTestSegment(t, walDir, "000000010000000100000002", []byte("second"))
|
||||
|
||||
var mu sync.Mutex
|
||||
var uploadOrder []string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
uploadOrder = append(uploadOrder, r.Header.Get("X-Wal-Segment-Name"))
|
||||
mu.Unlock()
|
||||
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
require.Len(t, uploadOrder, 3)
|
||||
assert.Equal(t, "000000010000000100000001", uploadOrder[0])
|
||||
assert.Equal(t, "000000010000000100000002", uploadOrder[1])
|
||||
assert.Equal(t, "000000010000000100000003", uploadOrder[2])
|
||||
}
|
||||
|
||||
func Test_UploadSegments_DirectoryHasTmpFiles_TmpFilesIgnored(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
writeTestSegment(t, walDir, "000000010000000100000001", []byte("real segment"))
|
||||
writeTestSegment(t, walDir, "000000010000000100000002.tmp", []byte("partial copy"))
|
||||
|
||||
var mu sync.Mutex
|
||||
var uploadedSegments []string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
uploadedSegments = append(uploadedSegments, r.Header.Get("X-Wal-Segment-Name"))
|
||||
mu.Unlock()
|
||||
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
require.Len(t, uploadedSegments, 1)
|
||||
assert.Equal(t, "000000010000000100000001", uploadedSegments[0])
|
||||
}
|
||||
|
||||
func Test_UploadSegment_DeleteEnabled_FileRemovedAfterUpload(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
segmentName := "000000010000000100000001"
|
||||
writeTestSegment(t, walDir, segmentName, []byte("segment data"))
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
isDeleteEnabled := true
|
||||
cfg := createTestConfig(walDir, server.URL)
|
||||
cfg.IsDeleteWalAfterUpload = &isDeleteEnabled
|
||||
apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger())
|
||||
streamer := NewStreamer(cfg, apiClient, logger.GetLogger())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
_, err := os.Stat(filepath.Join(walDir, segmentName))
|
||||
assert.True(t, os.IsNotExist(err), "segment file should be deleted after successful upload")
|
||||
}
|
||||
|
||||
func Test_UploadSegment_DeleteDisabled_FileKeptAfterUpload(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
segmentName := "000000010000000100000001"
|
||||
writeTestSegment(t, walDir, segmentName, []byte("segment data"))
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
isDeleteDisabled := false
|
||||
cfg := createTestConfig(walDir, server.URL)
|
||||
cfg.IsDeleteWalAfterUpload = &isDeleteDisabled
|
||||
apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger())
|
||||
streamer := NewStreamer(cfg, apiClient, logger.GetLogger())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
_, err := os.Stat(filepath.Join(walDir, segmentName))
|
||||
assert.NoError(t, err, "segment file should be kept when delete is disabled")
|
||||
}
|
||||
|
||||
func Test_UploadSegment_ServerReturns500_FileKeptInQueue(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
segmentName := "000000010000000100000001"
|
||||
writeTestSegment(t, walDir, segmentName, []byte("segment data"))
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(`{"error":"internal server error"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
_, err := os.Stat(filepath.Join(walDir, segmentName))
|
||||
assert.NoError(t, err, "segment file should remain in queue after server error")
|
||||
}
|
||||
|
||||
func Test_ProcessQueue_EmptyDirectory_NoUploads(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
|
||||
uploadCount := 0
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
uploadCount++
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
assert.Equal(t, 0, uploadCount, "no uploads should occur for empty directory")
|
||||
}
|
||||
|
||||
func Test_Run_ContextCancelled_StopsImmediately(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
|
||||
streamer := newTestStreamer(walDir, "http://localhost:0")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
streamer.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Run should have stopped immediately when context is already cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_UploadSegment_ServerReturns409_FileNotDeleted(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
segmentName := "000000010000000100000005"
|
||||
writeTestSegment(t, walDir, segmentName, []byte("gap segment"))
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
|
||||
resp := map[string]string{
|
||||
"error": "gap_detected",
|
||||
"expectedSegmentName": "000000010000000100000003",
|
||||
"receivedSegmentName": segmentName,
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
_, err := os.Stat(filepath.Join(walDir, segmentName))
|
||||
assert.NoError(t, err, "segment file should not be deleted on gap detection")
|
||||
}
|
||||
|
||||
func newTestStreamer(walDir, serverURL string) *Streamer {
|
||||
cfg := createTestConfig(walDir, serverURL)
|
||||
apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger())
|
||||
|
||||
return NewStreamer(cfg, apiClient, logger.GetLogger())
|
||||
}
|
||||
|
||||
// createTestWalDir creates a fresh directory for the calling test under
// ./.test-tmp, named after the test, and registers a cleanup that removes it
// when the test finishes. The test is failed immediately if either directory
// cannot be created.
func createTestWalDir(t *testing.T) string {
	t.Helper()

	root := filepath.Join(".", ".test-tmp")

	mkErr := os.MkdirAll(root, 0o755)
	if mkErr != nil {
		t.Fatalf("failed to create base test dir: %v", mkErr)
	}

	pattern := t.Name() + "-*"

	walDir, tmpErr := os.MkdirTemp(root, pattern)
	if tmpErr != nil {
		t.Fatalf("failed to create test wal dir: %v", tmpErr)
	}

	t.Cleanup(func() { _ = os.RemoveAll(walDir) })

	return walDir
}
|
||||
|
||||
// writeTestSegment writes content to a file called name inside dir and fails
// the test immediately if the write does not succeed.
func writeTestSegment(t *testing.T, dir, name string, content []byte) {
	t.Helper()

	target := filepath.Join(dir, name)

	writeErr := os.WriteFile(target, content, 0o644)
	if writeErr != nil {
		t.Fatalf("failed to write test segment %s: %v", name, writeErr)
	}
}
|
||||
|
||||
func createTestConfig(walDir, serverURL string) *config.Config {
|
||||
isDeleteEnabled := true
|
||||
|
||||
return &config.Config{
|
||||
DatabasusHost: serverURL,
|
||||
DbID: "test-db-id",
|
||||
Token: "test-token",
|
||||
PgWalDir: walDir,
|
||||
IsDeleteWalAfterUpload: &isDeleteEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func decompressZstd(t *testing.T, data []byte) []byte {
|
||||
t.Helper()
|
||||
|
||||
decoder, err := zstd.NewReader(nil)
|
||||
require.NoError(t, err)
|
||||
defer decoder.Close()
|
||||
|
||||
decoded, err := decoder.DecodeAll(data, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
return decoded
|
||||
}
|
||||
@@ -1,47 +1,119 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Log file settings used by buildWriter when it sets up the rotating writer.
const (
|
||||
// logFileName is the path of the active log file.
logFileName = "databasus.log"
|
||||
// oldLogFileName is the path the active log file is renamed to on rotation.
oldLogFileName = "databasus.log.old"
|
||||
// maxLogFileSize is the size limit that triggers a rotation.
maxLogFileSize = 5 * 1024 * 1024 // 5MB
|
||||
)
|
||||
|
||||
// rotatingWriter is an io.Writer over a log file that renames the file to a
// single backup path and starts a new one when a write would exceed maxSize.
type rotatingWriter struct {
|
||||
// mu serializes Write calls, including any rotation they trigger.
mu sync.Mutex
|
||||
// file is the currently open log file.
file *os.File
|
||||
// currentSize is the number of bytes accounted to the current file.
currentSize int64
|
||||
// maxSize is the limit checked before each write.
maxSize int64
|
||||
// logPath is the path of the active log file.
logPath string
|
||||
// oldLogPath is the path the active file is renamed to on rotation.
oldLogPath string
|
||||
}
|
||||
|
||||
func (w *rotatingWriter) Write(p []byte) (int, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.currentSize+int64(len(p)) > w.maxSize {
|
||||
if err := w.rotate(); err != nil {
|
||||
return 0, fmt.Errorf("failed to rotate log file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
n, err := w.file.Write(p)
|
||||
w.currentSize += int64(n)
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *rotatingWriter) rotate() error {
|
||||
if err := w.file.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close %s: %w", w.logPath, err)
|
||||
}
|
||||
|
||||
if err := os.Remove(w.oldLogPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove %s: %w", w.oldLogPath, err)
|
||||
}
|
||||
|
||||
if err := os.Rename(w.logPath, w.oldLogPath); err != nil {
|
||||
return fmt.Errorf("failed to rename %s to %s: %w", w.logPath, w.oldLogPath, err)
|
||||
}
|
||||
|
||||
f, err := os.OpenFile(w.logPath, os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new %s: %w", w.logPath, err)
|
||||
}
|
||||
|
||||
w.file = f
|
||||
w.currentSize = 0
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
loggerInstance *slog.Logger
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func Init(isDebug bool) {
|
||||
level := slog.LevelInfo
|
||||
if isDebug {
|
||||
level = slog.LevelDebug
|
||||
}
|
||||
|
||||
once.Do(func() {
|
||||
loggerInstance = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
|
||||
if a.Key == slog.TimeKey {
|
||||
a.Value = slog.StringValue(time.Now().Format("2006/01/02 15:04:05"))
|
||||
}
|
||||
if a.Key == slog.LevelKey {
|
||||
return slog.Attr{}
|
||||
}
|
||||
|
||||
return a
|
||||
},
|
||||
}))
|
||||
|
||||
loggerInstance.Info("Text structured logger initialized")
|
||||
})
|
||||
}
|
||||
|
||||
// GetLogger returns the singleton slog.Logger, creating it on first use
|
||||
func GetLogger() *slog.Logger {
|
||||
if loggerInstance == nil {
|
||||
Init(false)
|
||||
}
|
||||
once.Do(func() {
|
||||
initialize()
|
||||
})
|
||||
|
||||
return loggerInstance
|
||||
}
|
||||
|
||||
func initialize() {
|
||||
writer := buildWriter()
|
||||
|
||||
loggerInstance = slog.New(slog.NewTextHandler(writer, &slog.HandlerOptions{
|
||||
Level: slog.LevelInfo,
|
||||
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
|
||||
if a.Key == slog.TimeKey {
|
||||
a.Value = slog.StringValue(time.Now().Format("2006/01/02 15:04:05"))
|
||||
}
|
||||
if a.Key == slog.LevelKey {
|
||||
return slog.Attr{}
|
||||
}
|
||||
|
||||
return a
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
func buildWriter() io.Writer {
|
||||
f, err := os.OpenFile(logFileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: failed to open %s for logging: %v\n", logFileName, err)
|
||||
return os.Stdout
|
||||
}
|
||||
|
||||
var currentSize int64
|
||||
if info, err := f.Stat(); err == nil {
|
||||
currentSize = info.Size()
|
||||
}
|
||||
|
||||
rw := &rotatingWriter{
|
||||
file: f,
|
||||
currentSize: currentSize,
|
||||
maxSize: maxLogFileSize,
|
||||
logPath: logFileName,
|
||||
oldLogPath: oldLogFileName,
|
||||
}
|
||||
|
||||
return io.MultiWriter(os.Stdout, rw)
|
||||
}
|
||||
|
||||
128
agent/internal/logger/logger_test.go
Normal file
128
agent/internal/logger/logger_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Write_DataWrittenToFile(t *testing.T) {
|
||||
rw, logPath, _ := setupRotatingWriter(t, 1024)
|
||||
|
||||
data := []byte("hello world\n")
|
||||
n, err := rw.Write(data)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(data), n)
|
||||
assert.Equal(t, int64(len(data)), rw.currentSize)
|
||||
|
||||
content, err := os.ReadFile(logPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(data), string(content))
|
||||
}
|
||||
|
||||
func Test_Write_WhenLimitExceeded_FileRotated(t *testing.T) {
|
||||
rw, logPath, oldLogPath := setupRotatingWriter(t, 100)
|
||||
|
||||
firstData := []byte(strings.Repeat("A", 80))
|
||||
_, err := rw.Write(firstData)
|
||||
require.NoError(t, err)
|
||||
|
||||
secondData := []byte(strings.Repeat("B", 30))
|
||||
_, err = rw.Write(secondData)
|
||||
require.NoError(t, err)
|
||||
|
||||
oldContent, err := os.ReadFile(oldLogPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(firstData), string(oldContent))
|
||||
|
||||
newContent, err := os.ReadFile(logPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(secondData), string(newContent))
|
||||
|
||||
assert.Equal(t, int64(len(secondData)), rw.currentSize)
|
||||
}
|
||||
|
||||
func Test_Write_WhenOldFileExists_OldFileReplaced(t *testing.T) {
|
||||
rw, _, oldLogPath := setupRotatingWriter(t, 100)
|
||||
|
||||
require.NoError(t, os.WriteFile(oldLogPath, []byte("stale data"), 0o644))
|
||||
|
||||
_, err := rw.Write([]byte(strings.Repeat("A", 80)))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = rw.Write([]byte(strings.Repeat("B", 30)))
|
||||
require.NoError(t, err)
|
||||
|
||||
oldContent, err := os.ReadFile(oldLogPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, strings.Repeat("A", 80), string(oldContent))
|
||||
}
|
||||
|
||||
func Test_Write_MultipleSmallWrites_CurrentSizeAccumulated(t *testing.T) {
|
||||
rw, _, _ := setupRotatingWriter(t, 1024)
|
||||
|
||||
var totalWritten int64
|
||||
for i := 0; i < 10; i++ {
|
||||
data := []byte("line\n")
|
||||
n, err := rw.Write(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
totalWritten += int64(n)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalWritten, rw.currentSize)
|
||||
assert.Equal(t, int64(50), rw.currentSize)
|
||||
}
|
||||
|
||||
func Test_Write_ExactlyAtBoundary_NoRotationUntilNextByte(t *testing.T) {
|
||||
rw, logPath, oldLogPath := setupRotatingWriter(t, 100)
|
||||
|
||||
exactData := []byte(strings.Repeat("X", 100))
|
||||
_, err := rw.Write(exactData)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = os.Stat(oldLogPath)
|
||||
assert.True(t, os.IsNotExist(err), ".old file should not exist yet")
|
||||
|
||||
content, err := os.ReadFile(logPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(exactData), string(content))
|
||||
|
||||
_, err = rw.Write([]byte("Z"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = os.Stat(oldLogPath)
|
||||
assert.NoError(t, err, ".old file should exist after exceeding limit")
|
||||
|
||||
assert.Equal(t, int64(1), rw.currentSize)
|
||||
}
|
||||
|
||||
func setupRotatingWriter(t *testing.T, maxSize int64) (*rotatingWriter, string, string) {
|
||||
t.Helper()
|
||||
|
||||
dir := t.TempDir()
|
||||
logPath := filepath.Join(dir, "test.log")
|
||||
oldLogPath := filepath.Join(dir, "test.log.old")
|
||||
|
||||
f, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
rw := &rotatingWriter{
|
||||
file: f,
|
||||
currentSize: 0,
|
||||
maxSize: maxSize,
|
||||
logPath: logPath,
|
||||
oldLogPath: oldLogPath,
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
rw.file.Close()
|
||||
})
|
||||
|
||||
return rw, logPath, oldLogPath
|
||||
}
|
||||
@@ -59,6 +59,10 @@ func (c *BackupCleaner) Run(ctx context.Context) {
|
||||
if err := c.cleanExceededBackups(); err != nil {
|
||||
c.logger.Error("Failed to clean exceeded backups", "error", err)
|
||||
}
|
||||
|
||||
if err := c.cleanStaleUploadedBasebackups(); err != nil {
|
||||
c.logger.Error("Failed to clean stale uploaded basebackups", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -100,6 +104,67 @@ func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemo
|
||||
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
|
||||
staleBackups, err := c.backupRepository.FindStaleUploadedBasebackups(
|
||||
time.Now().UTC().Add(-10 * time.Minute),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find stale uploaded basebackups: %w", err)
|
||||
}
|
||||
|
||||
for _, backup := range staleBackups {
|
||||
staleStorage, storageErr := c.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr != nil {
|
||||
c.logger.Error(
|
||||
"Failed to get storage for stale basebackup cleanup",
|
||||
"backupId", backup.ID,
|
||||
"storageId", backup.StorageID,
|
||||
"error", storageErr,
|
||||
)
|
||||
} else {
|
||||
if err := staleStorage.DeleteFile(c.fieldEncryptor, backup.FileName); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete stale basebackup file",
|
||||
"backupId", backup.ID,
|
||||
"fileName", backup.FileName,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
|
||||
metadataFileName := backup.FileName + ".metadata"
|
||||
if err := staleStorage.DeleteFile(c.fieldEncryptor, metadataFileName); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete stale basebackup metadata file",
|
||||
"backupId", backup.ID,
|
||||
"fileName", metadataFileName,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
failMsg := "basebackup finalization timed out after 10 minutes"
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.FailMessage = &failMsg
|
||||
|
||||
if err := c.backupRepository.Save(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to mark stale uploaded basebackup as failed",
|
||||
"backupId", backup.ID,
|
||||
"error", err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Marked stale uploaded basebackup as failed and cleaned storage",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", backup.DatabaseID,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByRetentionPolicy() error {
|
||||
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
|
||||
@@ -1004,6 +1004,191 @@ func (m *mockBackupRemoveListener) OnBeforeBackupRemove(backup *backups_core.Bac
|
||||
return nil
|
||||
}
|
||||
|
||||
func Test_CleanStaleUploadedBasebackups_MarksAsFailed(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
staleTime := time.Now().UTC().Add(-15 * time.Minute)
|
||||
walBackupType := backups_core.PgWalBackupTypeFullBackup
|
||||
staleBackup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
PgWalBackupType: &walBackupType,
|
||||
UploadCompletedAt: &staleTime,
|
||||
CreatedAt: staleTime,
|
||||
}
|
||||
|
||||
err := backupRepository.Save(staleBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(staleBackup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusFailed, updated.Status)
|
||||
assert.NotNil(t, updated.FailMessage)
|
||||
assert.Contains(t, *updated.FailMessage, "finalization timed out")
|
||||
}
|
||||
|
||||
func Test_CleanStaleUploadedBasebackups_SkipsRecentUploads(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
recentTime := time.Now().UTC().Add(-2 * time.Minute)
|
||||
walBackupType := backups_core.PgWalBackupTypeFullBackup
|
||||
recentBackup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
PgWalBackupType: &walBackupType,
|
||||
UploadCompletedAt: &recentTime,
|
||||
CreatedAt: recentTime,
|
||||
}
|
||||
|
||||
err := backupRepository.Save(recentBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(recentBackup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, updated.Status)
|
||||
}
|
||||
|
||||
func Test_CleanStaleUploadedBasebackups_SkipsActiveStreaming(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
walBackupType := backups_core.PgWalBackupTypeFullBackup
|
||||
activeBackup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
PgWalBackupType: &walBackupType,
|
||||
CreatedAt: time.Now().UTC().Add(-30 * time.Minute),
|
||||
}
|
||||
|
||||
err := backupRepository.Save(activeBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(activeBackup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, updated.Status)
|
||||
assert.Nil(t, updated.UploadCompletedAt)
|
||||
}
|
||||
|
||||
func Test_CleanStaleUploadedBasebackups_CleansStorageFiles(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
staleTime := time.Now().UTC().Add(-15 * time.Minute)
|
||||
walBackupType := backups_core.PgWalBackupTypeFullBackup
|
||||
staleBackup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
PgWalBackupType: &walBackupType,
|
||||
UploadCompletedAt: &staleTime,
|
||||
BackupSizeMb: 500,
|
||||
FileName: "stale-basebackup-test-file",
|
||||
CreatedAt: staleTime,
|
||||
}
|
||||
|
||||
err := backupRepository.Save(staleBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(staleBackup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusFailed, updated.Status)
|
||||
assert.NotNil(t, updated.FailMessage)
|
||||
assert.Contains(t, *updated.FailMessage, "finalization timed out")
|
||||
}
|
||||
|
||||
func createTestInterval() *intervals.Interval {
|
||||
timeOfDay := "04:00"
|
||||
interval := &intervals.Interval{
|
||||
|
||||
@@ -3,12 +3,10 @@ package backups_controllers
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_dto "databasus-backend/internal/features/backups/backups/dto"
|
||||
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -25,8 +23,11 @@ func (c *PostgreWalBackupController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
walRoutes := router.Group("/backups/postgres/wal")
|
||||
|
||||
walRoutes.GET("/next-full-backup-time", c.GetNextFullBackupTime)
|
||||
walRoutes.GET("/is-wal-chain-valid-since-last-full-backup", c.IsWalChainValidSinceLastBackup)
|
||||
walRoutes.POST("/error", c.ReportError)
|
||||
walRoutes.POST("/upload", c.Upload)
|
||||
walRoutes.POST("/upload/wal", c.UploadWalSegment)
|
||||
walRoutes.POST("/upload/full-start", c.StartFullBackupUpload)
|
||||
walRoutes.POST("/upload/full-complete", c.CompleteFullBackupUpload)
|
||||
walRoutes.GET("/restore/plan", c.GetRestorePlan)
|
||||
walRoutes.GET("/restore/download", c.DownloadBackupFile)
|
||||
}
|
||||
@@ -90,91 +91,66 @@ func (c *PostgreWalBackupController) ReportError(ctx *gin.Context) {
|
||||
ctx.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
// Upload
|
||||
// @Summary Stream upload a basebackup or WAL segment
|
||||
// @Description Accepts a zstd-compressed binary stream and stores it in the database's configured storage.
|
||||
// The server generates the storage filename; agents do not control the destination path.
|
||||
// For WAL segment uploads the server validates the WAL chain and returns 409 if a gap is detected
|
||||
// or 400 if no full backup exists yet (agent should trigger a full basebackup in both cases).
|
||||
// IsWalChainValidSinceLastBackup
|
||||
// @Summary Check WAL chain validity since last full backup
|
||||
// @Description Checks whether the WAL chain is continuous since the last completed full backup.
|
||||
// Returns isValid=true if the chain is intact, or isValid=false with error details if not.
|
||||
// @Tags backups-wal
|
||||
// @Accept application/octet-stream
|
||||
// @Produce json
|
||||
// @Security AgentToken
|
||||
// @Param X-Upload-Type header string true "Upload type" Enums(basebackup, wal)
|
||||
// @Param X-Wal-Segment-Name header string false "24-hex WAL segment identifier (required for wal uploads, e.g. 0000000100000001000000AB)"
|
||||
// @Param X-Wal-Segment-Size header int false "WAL segment size in bytes reported by the PostgreSQL instance (default: 16777216)"
|
||||
// @Param fullBackupWalStartSegment query string false "First WAL segment needed to make the basebackup consistent (required for basebackup uploads)"
|
||||
// @Param fullBackupWalStopSegment query string false "Last WAL segment included in the basebackup (required for basebackup uploads)"
|
||||
// @Success 204
|
||||
// @Failure 400 {object} backups_dto.UploadGapResponse "No full backup exists (error: no_full_backup)"
|
||||
// @Success 200 {object} backups_dto.IsWalChainValidResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 409 {object} backups_dto.UploadGapResponse "WAL chain gap detected (error: gap_detected)"
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /backups/postgres/wal/upload [post]
|
||||
func (c *PostgreWalBackupController) Upload(ctx *gin.Context) {
|
||||
// @Router /backups/postgres/wal/is-wal-chain-valid-since-last-full-backup [get]
|
||||
func (c *PostgreWalBackupController) IsWalChainValidSinceLastBackup(ctx *gin.Context) {
|
||||
database, err := c.getDatabase(ctx)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
|
||||
return
|
||||
}
|
||||
|
||||
uploadType := backups_core.PgWalUploadType(ctx.GetHeader("X-Upload-Type"))
|
||||
if uploadType != backups_core.PgWalUploadTypeBasebackup &&
|
||||
uploadType != backups_core.PgWalUploadTypeWal {
|
||||
response, err := c.walService.IsWalChainValid(database)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// UploadWalSegment
|
||||
// @Summary Stream upload a WAL segment
|
||||
// @Description Accepts a zstd-compressed WAL segment binary stream and stores it in the database's configured storage.
|
||||
// WAL segments are accepted unconditionally.
|
||||
// @Tags backups-wal
|
||||
// @Accept application/octet-stream
|
||||
// @Security AgentToken
|
||||
// @Param X-Wal-Segment-Name header string true "24-hex WAL segment identifier (e.g. 0000000100000001000000AB)"
|
||||
// @Success 204
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /backups/postgres/wal/upload/wal [post]
|
||||
func (c *PostgreWalBackupController) UploadWalSegment(ctx *gin.Context) {
|
||||
database, err := c.getDatabase(ctx)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
|
||||
return
|
||||
}
|
||||
|
||||
walSegmentName := ctx.GetHeader("X-Wal-Segment-Name")
|
||||
if walSegmentName == "" {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "X-Upload-Type must be 'basebackup' or 'wal'"},
|
||||
gin.H{"error": "X-Wal-Segment-Name is required for wal uploads"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
walSegmentName := ""
|
||||
if uploadType == backups_core.PgWalUploadTypeWal {
|
||||
walSegmentName = ctx.GetHeader("X-Wal-Segment-Name")
|
||||
if walSegmentName == "" {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "X-Wal-Segment-Name is required for wal uploads"},
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if uploadType == backups_core.PgWalUploadTypeBasebackup {
|
||||
if ctx.Query("fullBackupWalStartSegment") == "" ||
|
||||
ctx.Query("fullBackupWalStopSegment") == "" {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{
|
||||
"error": "fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads",
|
||||
},
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
walSegmentSizeBytes := int64(0)
|
||||
if raw := ctx.GetHeader("X-Wal-Segment-Size"); raw != "" {
|
||||
parsed, parseErr := strconv.ParseInt(raw, 10, 64)
|
||||
if parseErr != nil || parsed <= 0 {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "X-Wal-Segment-Size must be a positive integer"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
walSegmentSizeBytes = parsed
|
||||
}
|
||||
|
||||
gapResp, uploadErr := c.walService.UploadWal(
|
||||
uploadErr := c.walService.UploadWalSegment(
|
||||
ctx.Request.Context(),
|
||||
database,
|
||||
uploadType,
|
||||
walSegmentName,
|
||||
ctx.Query("fullBackupWalStartSegment"),
|
||||
ctx.Query("fullBackupWalStopSegment"),
|
||||
walSegmentSizeBytes,
|
||||
ctx.Request.Body,
|
||||
)
|
||||
|
||||
@@ -183,17 +159,89 @@ func (c *PostgreWalBackupController) Upload(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if gapResp != nil {
|
||||
if gapResp.Error == "no_full_backup" {
|
||||
ctx.JSON(http.StatusBadRequest, gapResp)
|
||||
return
|
||||
}
|
||||
ctx.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusConflict, gapResp)
|
||||
// StartFullBackupUpload
|
||||
// @Summary Stream upload a full basebackup (Phase 1)
|
||||
// @Description Accepts a zstd-compressed basebackup binary stream and stores it in the database's configured storage.
|
||||
// Returns a backupId that must be completed via /upload/full-complete with WAL segment names.
|
||||
// @Tags backups-wal
|
||||
// @Accept application/octet-stream
|
||||
// @Produce json
|
||||
// @Security AgentToken
|
||||
// @Success 200 {object} backups_dto.UploadBasebackupResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /backups/postgres/wal/upload/full-start [post]
|
||||
func (c *PostgreWalBackupController) StartFullBackupUpload(ctx *gin.Context) {
|
||||
database, err := c.getDatabase(ctx)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusNoContent)
|
||||
backupID, uploadErr := c.walService.UploadBasebackup(
|
||||
ctx.Request.Context(),
|
||||
database,
|
||||
ctx.Request.Body,
|
||||
)
|
||||
|
||||
if uploadErr != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": uploadErr.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, backups_dto.UploadBasebackupResponse{
|
||||
BackupID: backupID,
|
||||
})
|
||||
}
|
||||
|
||||
// CompleteFullBackupUpload
|
||||
// @Summary Complete a previously uploaded basebackup (Phase 2)
|
||||
// @Description Sets WAL segment names and marks the basebackup as completed, or marks it as failed if an error is provided.
|
||||
// @Tags backups-wal
|
||||
// @Accept json
|
||||
// @Security AgentToken
|
||||
// @Param request body backups_dto.FinalizeBasebackupRequest true "Completion details"
|
||||
// @Success 200
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /backups/postgres/wal/upload/full-complete [post]
|
||||
func (c *PostgreWalBackupController) CompleteFullBackupUpload(ctx *gin.Context) {
|
||||
database, err := c.getDatabase(ctx)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
|
||||
return
|
||||
}
|
||||
|
||||
var request backups_dto.FinalizeBasebackupRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if request.Error == nil && (request.StartSegment == "" || request.StopSegment == "") {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "startSegment and stopSegment are required when no error is provided"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.walService.FinalizeBasebackup(
|
||||
database,
|
||||
request.BackupID,
|
||||
request.StartSegment,
|
||||
request.StopSegment,
|
||||
request.Error,
|
||||
); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
// GetRestorePlan
|
||||
|
||||
@@ -38,7 +38,7 @@ func Test_WalUpload_InProgressStatusSetBeforeStream(t *testing.T) {
|
||||
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
req := newWalUploadRequest(pr, agentToken, "wal", "000000010000000100000011", "", "")
|
||||
req := newWalSegmentUploadRequest(pr, agentToken, "000000010000000100000011")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
done := make(chan struct{})
|
||||
@@ -67,7 +67,7 @@ func Test_WalUpload_CompletedStatusAfterSuccessfulStream(t *testing.T) {
|
||||
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
|
||||
|
||||
body := bytes.NewReader([]byte("wal segment content"))
|
||||
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000011", "", "")
|
||||
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000011")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
@@ -99,7 +99,7 @@ func Test_WalUpload_FailedStatusWithErrorOnStreamError(t *testing.T) {
|
||||
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
req := newWalUploadRequest(pr, agentToken, "wal", "000000010000000100000011", "", "")
|
||||
req := newWalSegmentUploadRequest(pr, agentToken, "000000010000000100000011")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
done := make(chan struct{})
|
||||
@@ -129,59 +129,171 @@ func Test_WalUpload_FailedStatusWithErrorOnStreamError(t *testing.T) {
|
||||
assert.NotNil(t, walBackup.FailMessage)
|
||||
}
|
||||
|
||||
func Test_WalUpload_Basebackup_MissingWalSegments_Returns400(t *testing.T) {
|
||||
func Test_WalUpload_Basebackup_StreamingUpload_Returns200WithBackupId(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
body := bytes.NewReader([]byte("basebackup content"))
|
||||
req := newWalUploadRequest(body, agentToken, backups_core.PgWalUploadTypeBasebackup, "", "", "")
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/full-start", body)
|
||||
req.Header.Set("Authorization", agentToken)
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response backups_dto.UploadBasebackupResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
assert.NotEqual(t, uuid.Nil, response.BackupID)
|
||||
|
||||
backup, err := backups_core.GetBackupRepository().FindByID(response.BackupID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, backup.Status)
|
||||
assert.NotNil(t, backup.UploadCompletedAt)
|
||||
}
|
||||
|
||||
func Test_FinalizeBasebackup_ValidSegments_MarksCompleted(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
backupID := uploadBasebackupPhase1(t, router, agentToken)
|
||||
|
||||
completeFullBackupUpload(t, router, agentToken, backupID,
|
||||
"000000010000000100000001", "000000010000000100000010", nil)
|
||||
|
||||
backup, err := backups_core.GetBackupRepository().FindByID(backupID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
require.NotNil(t, backup.PgFullBackupWalStartSegmentName)
|
||||
assert.Equal(t, "000000010000000100000001", *backup.PgFullBackupWalStartSegmentName)
|
||||
require.NotNil(t, backup.PgFullBackupWalStopSegmentName)
|
||||
assert.Equal(t, "000000010000000100000010", *backup.PgFullBackupWalStopSegmentName)
|
||||
}
|
||||
|
||||
func Test_FinalizeBasebackup_WithError_MarksFailed(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
backupID := uploadBasebackupPhase1(t, router, agentToken)
|
||||
|
||||
errMsg := "pg_basebackup stderr parse failed"
|
||||
completeFullBackupUpload(t, router, agentToken, backupID, "", "", &errMsg)
|
||||
|
||||
backup, err := backups_core.GetBackupRepository().FindByID(backupID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusFailed, backup.Status)
|
||||
require.NotNil(t, backup.FailMessage)
|
||||
assert.Equal(t, errMsg, *backup.FailMessage)
|
||||
}
|
||||
|
||||
func Test_FinalizeBasebackup_InvalidBackupId_Returns400(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
nonExistentID := uuid.New()
|
||||
body, _ := json.Marshal(backups_dto.FinalizeBasebackupRequest{
|
||||
BackupID: nonExistentID,
|
||||
StartSegment: "000000010000000100000001",
|
||||
StopSegment: "000000010000000100000010",
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/backups/postgres/wal/upload/full-complete",
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
req.Header.Set("Authorization", agentToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
func Test_WalUpload_WalSegment_NoFullBackup_Returns400(t *testing.T) {
|
||||
func Test_FinalizeBasebackup_AlreadyCompleted_Returns400(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
backupID := uploadBasebackupPhase1(t, router, agentToken)
|
||||
|
||||
completeFullBackupUpload(t, router, agentToken, backupID,
|
||||
"000000010000000100000001", "000000010000000100000010", nil)
|
||||
|
||||
// Second finalize should fail.
|
||||
body, _ := json.Marshal(backups_dto.FinalizeBasebackupRequest{
|
||||
BackupID: backupID,
|
||||
StartSegment: "000000010000000100000001",
|
||||
StopSegment: "000000010000000100000010",
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/backups/postgres/wal/upload/full-complete",
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
req.Header.Set("Authorization", agentToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
func Test_FinalizeBasebackup_InvalidToken_Returns401(t *testing.T) {
|
||||
router, db, storage, _, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
body, _ := json.Marshal(backups_dto.FinalizeBasebackupRequest{
|
||||
BackupID: uuid.New(),
|
||||
StartSegment: "000000010000000100000001",
|
||||
StopSegment: "000000010000000100000010",
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/backups/postgres/wal/upload/full-complete",
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
req.Header.Set("Authorization", "invalid-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
|
||||
func Test_WalUpload_WalSegment_WithoutFullBackup_Returns204(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
// No full backup inserted — chain anchor is missing.
|
||||
body := bytes.NewReader([]byte("wal content"))
|
||||
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000001", "", "")
|
||||
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000001")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
|
||||
var resp backups_dto.UploadGapResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||
assert.Equal(t, "no_full_backup", resp.Error)
|
||||
assert.Equal(t, http.StatusNoContent, w.Code)
|
||||
}
|
||||
|
||||
func Test_WalUpload_WalSegment_GapDetected_Returns409WithExpectedAndReceived(t *testing.T) {
|
||||
func Test_WalUpload_WalSegment_WithGap_Returns204(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
// Full backup stops at ...0010; upload one WAL segment at ...0011.
|
||||
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
|
||||
|
||||
// Send ...0013 — should be rejected because ...0012 is missing.
|
||||
// Skip ...0012, upload ...0013 — should succeed (no chain validation on upload).
|
||||
body := bytes.NewReader([]byte("wal content"))
|
||||
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000013", "", "")
|
||||
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000013")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusConflict, w.Code)
|
||||
|
||||
var resp backups_dto.UploadGapResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||
assert.Equal(t, "gap_detected", resp.Error)
|
||||
assert.Equal(t, "000000010000000100000012", resp.ExpectedSegmentName)
|
||||
assert.Equal(t, "000000010000000100000013", resp.ReceivedSegmentName)
|
||||
assert.Equal(t, http.StatusNoContent, w.Code)
|
||||
}
|
||||
|
||||
func Test_WalUpload_WalSegment_DuplicateSegment_Returns200Idempotent(t *testing.T) {
|
||||
@@ -192,14 +304,14 @@ func Test_WalUpload_WalSegment_DuplicateSegment_Returns200Idempotent(t *testing.
|
||||
|
||||
// Upload ...0011 once.
|
||||
body1 := bytes.NewReader([]byte("wal content"))
|
||||
req1 := newWalUploadRequest(body1, agentToken, "wal", "000000010000000100000011", "", "")
|
||||
req1 := newWalSegmentUploadRequest(body1, agentToken, "000000010000000100000011")
|
||||
w1 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w1, req1)
|
||||
require.Equal(t, http.StatusNoContent, w1.Code)
|
||||
|
||||
// Upload the same segment again — must return 204 (idempotent).
|
||||
body2 := bytes.NewReader([]byte("wal content"))
|
||||
req2 := newWalUploadRequest(body2, agentToken, "wal", "000000010000000100000011", "", "")
|
||||
req2 := newWalSegmentUploadRequest(body2, agentToken, "000000010000000100000011")
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
@@ -228,7 +340,7 @@ func Test_WalUpload_WalSegment_ValidNextSegment_Returns200AndCreatesRecord(t *te
|
||||
|
||||
// First WAL segment after the full backup stop segment.
|
||||
body := bytes.NewReader([]byte("wal segment data"))
|
||||
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000011", "", "")
|
||||
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000011")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
@@ -255,6 +367,108 @@ func Test_WalUpload_WalSegment_ValidNextSegment_Returns200AndCreatesRecord(t *te
|
||||
assert.Equal(t, "000000010000000100000011", *walBackup.PgWalSegmentName)
|
||||
}
|
||||
|
||||
func Test_IsWalChainValid_NoFullBackup_ReturnsFalse(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
var response backups_dto.IsWalChainValidResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t, router,
|
||||
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
|
||||
agentToken,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.False(t, response.IsValid)
|
||||
assert.Equal(t, "no_full_backup", response.Error)
|
||||
}
|
||||
|
||||
func Test_IsWalChainValid_FullBackupOnly_ReturnsTrue(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
|
||||
|
||||
var response backups_dto.IsWalChainValidResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t, router,
|
||||
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
|
||||
agentToken,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.True(t, response.IsValid)
|
||||
assert.Empty(t, response.Error)
|
||||
}
|
||||
|
||||
func Test_IsWalChainValid_ContinuousChain_ReturnsTrue(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000012")
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000013")
|
||||
|
||||
var response backups_dto.IsWalChainValidResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t, router,
|
||||
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
|
||||
agentToken,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.True(t, response.IsValid)
|
||||
}
|
||||
|
||||
func Test_IsWalChainValid_BrokenChain_ReturnsFalse(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000012")
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000013")
|
||||
|
||||
// Delete the middle segment to create a gap.
|
||||
middleSeg, err := backups_core.GetBackupRepository().FindWalSegmentByName(
|
||||
db.ID, "000000010000000100000012",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, middleSeg)
|
||||
require.NoError(t, backups_core.GetBackupRepository().DeleteByID(middleSeg.ID))
|
||||
|
||||
var response backups_dto.IsWalChainValidResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t, router,
|
||||
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
|
||||
agentToken,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.False(t, response.IsValid)
|
||||
assert.Equal(t, "wal_chain_broken", response.Error)
|
||||
assert.Equal(t, "000000010000000100000011", response.LastContiguousSegment)
|
||||
}
|
||||
|
||||
func Test_IsWalChainValid_InvalidToken_Returns401(t *testing.T) {
|
||||
router, db, storage, _, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
resp := test_utils.MakeGetRequest(
|
||||
t, router,
|
||||
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
|
||||
"invalid-token",
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "invalid agent token")
|
||||
}
|
||||
|
||||
func Test_ReportError_ValidTokenAndError_CreatesFailedBackupRecord(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
@@ -457,29 +671,14 @@ func Test_GetNextFullBackupTime_WalSegmentAfterFullBackup_DoesNotImpactTime(t *t
|
||||
|
||||
setHourlyInterval(t, router, db.ID, ownerToken)
|
||||
|
||||
// Upload basebackup via API.
|
||||
bbBody := bytes.NewReader([]byte("basebackup content"))
|
||||
bbReq := newWalUploadRequest(
|
||||
bbBody, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
|
||||
"000000010000000100000001", "000000010000000100000010",
|
||||
)
|
||||
bbW := httptest.NewRecorder()
|
||||
router.ServeHTTP(bbW, bbReq)
|
||||
require.Equal(t, http.StatusNoContent, bbW.Code)
|
||||
uploadBasebackup(t, router, agentToken,
|
||||
"000000010000000100000001", "000000010000000100000010")
|
||||
|
||||
// Shift the full backup's CreatedAt to 2 hours ago.
|
||||
twoHoursAgo := time.Now().UTC().Add(-2 * time.Hour)
|
||||
updateLastFullBackupTime(t, db.ID, twoHoursAgo)
|
||||
|
||||
// Upload WAL segment via API.
|
||||
walBody := bytes.NewReader([]byte("wal segment content"))
|
||||
walReq := newWalUploadRequest(
|
||||
walBody, agentToken, backups_core.PgWalUploadTypeWal,
|
||||
"000000010000000100000011", "", "",
|
||||
)
|
||||
walW := httptest.NewRecorder()
|
||||
router.ServeHTTP(walW, walReq)
|
||||
require.Equal(t, http.StatusNoContent, walW.Code)
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
|
||||
|
||||
var response backups_dto.GetNextFullBackupTimeResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
@@ -508,15 +707,8 @@ func Test_GetNextFullBackupTime_FailedBasebackup_DoesNotImpactTime(t *testing.T)
|
||||
|
||||
setHourlyInterval(t, router, db.ID, ownerToken)
|
||||
|
||||
// Upload a successful basebackup via API.
|
||||
bbBody := bytes.NewReader([]byte("basebackup content"))
|
||||
bbReq := newWalUploadRequest(
|
||||
bbBody, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
|
||||
"000000010000000100000001", "000000010000000100000010",
|
||||
)
|
||||
bbW := httptest.NewRecorder()
|
||||
router.ServeHTTP(bbW, bbReq)
|
||||
require.Equal(t, http.StatusNoContent, bbW.Code)
|
||||
uploadBasebackup(t, router, agentToken,
|
||||
"000000010000000100000001", "000000010000000100000010")
|
||||
|
||||
// Shift the full backup's CreatedAt to 2 hours ago.
|
||||
twoHoursAgo := time.Now().UTC().Add(-2 * time.Hour)
|
||||
@@ -563,15 +755,8 @@ func Test_GetNextFullBackupTime_NewCompletedFullBackup_ImpactsTime(t *testing.T)
|
||||
|
||||
setHourlyInterval(t, router, db.ID, ownerToken)
|
||||
|
||||
// Upload first basebackup via API.
|
||||
bb1 := bytes.NewReader([]byte("first basebackup"))
|
||||
bb1Req := newWalUploadRequest(
|
||||
bb1, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
|
||||
"000000010000000100000001", "000000010000000100000010",
|
||||
)
|
||||
bb1W := httptest.NewRecorder()
|
||||
router.ServeHTTP(bb1W, bb1Req)
|
||||
require.Equal(t, http.StatusNoContent, bb1W.Code)
|
||||
uploadBasebackup(t, router, agentToken,
|
||||
"000000010000000100000001", "000000010000000100000010")
|
||||
|
||||
// Shift the first backup's CreatedAt to 3 hours ago.
|
||||
threeHoursAgo := time.Now().UTC().Add(-3 * time.Hour)
|
||||
@@ -595,15 +780,8 @@ func Test_GetNextFullBackupTime_NewCompletedFullBackup_ImpactsTime(t *testing.T)
|
||||
"first next time should be in the past (old backup)",
|
||||
)
|
||||
|
||||
// Upload second basebackup via API (created now).
|
||||
bb2 := bytes.NewReader([]byte("second basebackup"))
|
||||
bb2Req := newWalUploadRequest(
|
||||
bb2, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
|
||||
"000000010000000100000011", "000000010000000100000020",
|
||||
)
|
||||
bb2W := httptest.NewRecorder()
|
||||
router.ServeHTTP(bb2W, bb2Req)
|
||||
require.Equal(t, http.StatusNoContent, bb2W.Code)
|
||||
uploadBasebackup(t, router, agentToken,
|
||||
"000000010000000100000011", "000000010000000100000020")
|
||||
|
||||
var secondResponse backups_dto.GetNextFullBackupTimeResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
@@ -841,15 +1019,18 @@ func Test_DownloadRestoreFile_UploadThenDownload_ContentMatches(t *testing.T) {
|
||||
|
||||
uploadContent := "test-basebackup-content-for-download"
|
||||
body := bytes.NewReader([]byte(uploadContent))
|
||||
req := newWalUploadRequest(
|
||||
body, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
|
||||
"000000010000000100000001", "000000010000000100000010",
|
||||
)
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/full-start", body)
|
||||
req.Header.Set("Authorization", agentToken)
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusNoContent, w.Code)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
WaitForBackupCompletion(t, db.ID, 0, 5*time.Second)
|
||||
var uploadResp backups_dto.UploadBasebackupResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &uploadResp))
|
||||
|
||||
completeFullBackupUpload(t, router, agentToken, uploadResp.BackupID,
|
||||
"000000010000000100000001", "000000010000000100000010", nil)
|
||||
|
||||
var planResp backups_dto.GetRestorePlanResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
@@ -883,7 +1064,7 @@ func Test_DownloadRestoreFile_WalSegment_UploadThenDownload_ContentMatches(t *te
|
||||
|
||||
walContent := "test-wal-segment-content-for-download"
|
||||
body := bytes.NewReader([]byte(walContent))
|
||||
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000011", "", "")
|
||||
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000011")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusNoContent, w.Code)
|
||||
@@ -1088,35 +1269,81 @@ func removeWalTestSetup(db *databases.Database, storage *storages.Storage) {
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
}
|
||||
|
||||
func newWalUploadRequest(
|
||||
func newWalSegmentUploadRequest(
|
||||
body io.Reader,
|
||||
agentToken string,
|
||||
uploadType backups_core.PgWalUploadType,
|
||||
walSegmentName string,
|
||||
walStart string,
|
||||
walStop string,
|
||||
segmentName string,
|
||||
) *http.Request {
|
||||
url := "/api/v1/backups/postgres/wal/upload"
|
||||
if walStart != "" || walStop != "" {
|
||||
url += "?fullBackupWalStartSegment=" + walStart + "&fullBackupWalStopSegment=" + walStop
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, url, body)
|
||||
req, err := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/wal", body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", agentToken)
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
req.Header.Set("X-Upload-Type", string(uploadType))
|
||||
|
||||
if walSegmentName != "" {
|
||||
req.Header.Set("X-Wal-Segment-Name", walSegmentName)
|
||||
}
|
||||
req.Header.Set("X-Wal-Segment-Name", segmentName)
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func uploadBasebackupPhase1(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
agentToken string,
|
||||
) uuid.UUID {
|
||||
t.Helper()
|
||||
|
||||
body := bytes.NewReader([]byte("test-basebackup-content"))
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/full-start", body)
|
||||
req.Header.Set("Authorization", agentToken)
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response backups_dto.UploadBasebackupResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
require.NotEqual(t, uuid.Nil, response.BackupID)
|
||||
|
||||
return response.BackupID
|
||||
}
|
||||
|
||||
func completeFullBackupUpload(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
agentToken string,
|
||||
backupID uuid.UUID,
|
||||
walStart string,
|
||||
walStop string,
|
||||
errMsg *string,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
request := backups_dto.FinalizeBasebackupRequest{
|
||||
BackupID: backupID,
|
||||
StartSegment: walStart,
|
||||
StopSegment: walStop,
|
||||
Error: errMsg,
|
||||
}
|
||||
|
||||
reqBody, _ := json.Marshal(request)
|
||||
req, _ := http.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/backups/postgres/wal/upload/full-complete",
|
||||
bytes.NewReader(reqBody),
|
||||
)
|
||||
req.Header.Set("Authorization", agentToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func uploadBasebackup(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
@@ -1126,15 +1353,8 @@ func uploadBasebackup(
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
body := bytes.NewReader([]byte("test-basebackup-content"))
|
||||
req := newWalUploadRequest(
|
||||
body, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
|
||||
walStart, walStop,
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusNoContent, w.Code)
|
||||
backupID := uploadBasebackupPhase1(t, router, agentToken)
|
||||
completeFullBackupUpload(t, router, agentToken, backupID, walStart, walStop, nil)
|
||||
}
|
||||
|
||||
func uploadWalSegment(
|
||||
@@ -1146,9 +1366,12 @@ func uploadWalSegment(
|
||||
t.Helper()
|
||||
|
||||
body := bytes.NewReader([]byte("test-wal-segment-content"))
|
||||
req := newWalUploadRequest(
|
||||
body, agentToken, backups_core.PgWalUploadTypeWal, segmentName, "", "",
|
||||
)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/wal", body)
|
||||
req.Header.Set("Authorization", agentToken)
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
req.Header.Set("X-Wal-Segment-Name", segmentName)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
|
||||
@@ -43,7 +43,8 @@ type Backup struct {
|
||||
PgVersion *string `json:"pgVersion" gorm:"column:pg_version;type:text"`
|
||||
PgWalSegmentName *string `json:"pgWalSegmentName" gorm:"column:pg_wal_segment_name;type:text"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
||||
UploadCompletedAt *time.Time `json:"uploadCompletedAt" gorm:"column:upload_completed_at"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
||||
}
|
||||
|
||||
func (b *Backup) GenerateFilename(dbName string) {
|
||||
|
||||
@@ -349,6 +349,24 @@ func (r *BackupRepository) FindWalSegmentByName(
|
||||
return &backup, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) FindStaleUploadedBasebackups(olderThan time.Time) ([]*Backup, error) {
|
||||
var backups []*Backup
|
||||
|
||||
err := storage.
|
||||
GetDb().
|
||||
Where(
|
||||
"status = ? AND upload_completed_at IS NOT NULL AND upload_completed_at < ?",
|
||||
BackupStatusInProgress,
|
||||
olderThan,
|
||||
).
|
||||
Find(&backups).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return backups, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) FindLastWalSegmentAfter(
|
||||
databaseID uuid.UUID,
|
||||
afterSegmentName string,
|
||||
|
||||
@@ -44,10 +44,10 @@ type ReportErrorRequest struct {
|
||||
Error string `json:"error" binding:"required"`
|
||||
}
|
||||
|
||||
type UploadGapResponse struct {
|
||||
Error string `json:"error"`
|
||||
ExpectedSegmentName string `json:"expectedSegmentName"`
|
||||
ReceivedSegmentName string `json:"receivedSegmentName"`
|
||||
type IsWalChainValidResponse struct {
|
||||
IsValid bool `json:"isValid"`
|
||||
Error string `json:"error,omitempty"`
|
||||
LastContiguousSegment string `json:"lastContiguousSegment,omitempty"`
|
||||
}
|
||||
|
||||
type RestorePlanFullBackup struct {
|
||||
@@ -77,3 +77,14 @@ type GetRestorePlanResponse struct {
|
||||
TotalSizeBytes int64 `json:"totalSizeBytes"`
|
||||
LatestAvailableSegment string `json:"latestAvailableSegment"`
|
||||
}
|
||||
|
||||
type UploadBasebackupResponse struct {
|
||||
BackupID uuid.UUID `json:"backupId"`
|
||||
}
|
||||
|
||||
type FinalizeBasebackupRequest struct {
|
||||
BackupID uuid.UUID `json:"backupId" binding:"required"`
|
||||
StartSegment string `json:"startSegment"`
|
||||
StopSegment string `json:"stopSegment"`
|
||||
Error *string `json:"error"`
|
||||
}
|
||||
|
||||
@@ -30,75 +30,46 @@ type PostgreWalBackupService struct {
|
||||
backupService *BackupService
|
||||
}
|
||||
|
||||
// UploadWal accepts a streaming WAL segment or basebackup upload from the agent.
|
||||
// For WAL segments it validates the WAL chain before accepting. Returns an UploadGapResponse
|
||||
// (409) when the chain is broken so the agent knows to trigger a full basebackup.
|
||||
func (s *PostgreWalBackupService) UploadWal(
|
||||
// UploadWalSegment accepts a streaming WAL segment upload from the agent.
|
||||
// WAL segments are accepted unconditionally.
|
||||
func (s *PostgreWalBackupService) UploadWalSegment(
|
||||
ctx context.Context,
|
||||
database *databases.Database,
|
||||
uploadType backups_core.PgWalUploadType,
|
||||
walSegmentName string,
|
||||
fullBackupWalStartSegment string,
|
||||
fullBackupWalStopSegment string,
|
||||
walSegmentSizeBytes int64,
|
||||
body io.Reader,
|
||||
) (*backups_dto.UploadGapResponse, error) {
|
||||
) error {
|
||||
if err := s.validateWalBackupType(database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if uploadType == backups_core.PgWalUploadTypeBasebackup {
|
||||
if fullBackupWalStartSegment == "" || fullBackupWalStopSegment == "" {
|
||||
return nil, fmt.Errorf(
|
||||
"fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads",
|
||||
)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get backup config: %w", err)
|
||||
return fmt.Errorf("failed to get backup config: %w", err)
|
||||
}
|
||||
|
||||
if backupConfig.Storage == nil {
|
||||
return nil, fmt.Errorf("no storage configured for database %s", database.ID)
|
||||
return fmt.Errorf("no storage configured for database %s", database.ID)
|
||||
}
|
||||
|
||||
if uploadType == backups_core.PgWalUploadTypeWal {
|
||||
// Idempotency: check before chain validation so a successful re-upload is
|
||||
// not misidentified as a gap.
|
||||
existing, err := s.backupRepository.FindWalSegmentByName(database.ID, walSegmentName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check for duplicate WAL segment: %w", err)
|
||||
}
|
||||
|
||||
if existing != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
gapResp, err := s.validateWalChain(database.ID, walSegmentName, walSegmentSizeBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gapResp != nil {
|
||||
return gapResp, nil
|
||||
}
|
||||
existing, err := s.backupRepository.FindWalSegmentByName(database.ID, walSegmentName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check for duplicate WAL segment: %w", err)
|
||||
}
|
||||
|
||||
backup := s.createBackupRecord(
|
||||
if existing != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
backup := s.createWalSegmentRecord(
|
||||
database.ID,
|
||||
backupConfig.Storage.ID,
|
||||
uploadType,
|
||||
database.Name,
|
||||
walSegmentName,
|
||||
fullBackupWalStartSegment,
|
||||
fullBackupWalStopSegment,
|
||||
backupConfig.Encryption,
|
||||
)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return nil, fmt.Errorf("failed to create backup record: %w", err)
|
||||
return fmt.Errorf("failed to create backup record: %w", err)
|
||||
}
|
||||
|
||||
sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, body)
|
||||
@@ -106,12 +77,106 @@ func (s *PostgreWalBackupService) UploadWal(
|
||||
errMsg := streamErr.Error()
|
||||
s.markFailed(backup, errMsg)
|
||||
|
||||
return nil, fmt.Errorf("upload failed: %w", streamErr)
|
||||
return fmt.Errorf("upload failed: %w", streamErr)
|
||||
}
|
||||
|
||||
s.markCompleted(backup, sizeBytes)
|
||||
|
||||
return nil, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// UploadBasebackup accepts a streaming basebackup upload from the agent (Phase 1).
|
||||
// The backup stays IN_PROGRESS with UploadCompletedAt set after streaming finishes.
|
||||
// The agent must call FinalizeBasebackup (Phase 2) with WAL segment names to complete.
|
||||
func (s *PostgreWalBackupService) UploadBasebackup(
|
||||
ctx context.Context,
|
||||
database *databases.Database,
|
||||
body io.Reader,
|
||||
) (uuid.UUID, error) {
|
||||
if err := s.validateWalBackupType(database); err != nil {
|
||||
return uuid.Nil, err
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
return uuid.Nil, fmt.Errorf("failed to get backup config: %w", err)
|
||||
}
|
||||
|
||||
if backupConfig.Storage == nil {
|
||||
return uuid.Nil, fmt.Errorf("no storage configured for database %s", database.ID)
|
||||
}
|
||||
|
||||
backup := s.createBasebackupRecord(
|
||||
database.ID,
|
||||
backupConfig.Storage.ID,
|
||||
database.Name,
|
||||
backupConfig.Encryption,
|
||||
)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return uuid.Nil, fmt.Errorf("failed to create backup record: %w", err)
|
||||
}
|
||||
|
||||
sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, body)
|
||||
if streamErr != nil {
|
||||
errMsg := streamErr.Error()
|
||||
s.markFailed(backup, errMsg)
|
||||
|
||||
return uuid.Nil, fmt.Errorf("upload failed: %w", streamErr)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
backup.UploadCompletedAt = &now
|
||||
backup.BackupSizeMb = float64(sizeBytes) / (1024 * 1024)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return uuid.Nil, fmt.Errorf("failed to update backup after upload: %w", err)
|
||||
}
|
||||
|
||||
return backup.ID, nil
|
||||
}
|
||||
|
||||
// FinalizeBasebackup completes a previously uploaded basebackup (Phase 2).
|
||||
// Sets WAL segment names and marks the backup as COMPLETED, or marks it FAILED if errorMsg is provided.
|
||||
func (s *PostgreWalBackupService) FinalizeBasebackup(
|
||||
database *databases.Database,
|
||||
backupID uuid.UUID,
|
||||
startSegment string,
|
||||
stopSegment string,
|
||||
errorMsg *string,
|
||||
) error {
|
||||
if err := s.validateWalBackupType(database); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backup not found: %w", err)
|
||||
}
|
||||
|
||||
if backup.DatabaseID != database.ID {
|
||||
return fmt.Errorf("backup does not belong to this database")
|
||||
}
|
||||
|
||||
if backup.Status != backups_core.BackupStatusInProgress || backup.UploadCompletedAt == nil {
|
||||
return fmt.Errorf("backup is not awaiting finalization")
|
||||
}
|
||||
|
||||
if errorMsg != nil {
|
||||
s.markFailed(backup, *errorMsg)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
backup.PgFullBackupWalStartSegmentName = &startSegment
|
||||
backup.PgFullBackupWalStopSegmentName = &stopSegment
|
||||
backup.Status = backups_core.BackupStatusCompleted
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return fmt.Errorf("failed to finalize backup: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PostgreWalBackupService) GetRestorePlan(
|
||||
@@ -299,97 +364,97 @@ func (s *PostgreWalBackupService) ReportError(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PostgreWalBackupService) validateWalChain(
|
||||
databaseID uuid.UUID,
|
||||
incomingSegment string,
|
||||
walSegmentSizeBytes int64,
|
||||
) (*backups_dto.UploadGapResponse, error) {
|
||||
fullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(databaseID)
|
||||
// IsWalChainValid checks whether the WAL chain is continuous since the last completed full backup.
|
||||
func (s *PostgreWalBackupService) IsWalChainValid(
|
||||
database *databases.Database,
|
||||
) (*backups_dto.IsWalChainValidResponse, error) {
|
||||
if err := s.validateWalBackupType(database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(database.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query full backup: %w", err)
|
||||
}
|
||||
|
||||
// No full backup exists yet: cannot accept WAL segments without a chain anchor.
|
||||
if fullBackup == nil || fullBackup.PgFullBackupWalStopSegmentName == nil {
|
||||
return &backups_dto.UploadGapResponse{
|
||||
Error: "no_full_backup",
|
||||
ExpectedSegmentName: "",
|
||||
ReceivedSegmentName: incomingSegment,
|
||||
return &backups_dto.IsWalChainValidResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
}, nil
|
||||
}
|
||||
|
||||
stopSegment := *fullBackup.PgFullBackupWalStopSegmentName
|
||||
startSegment := ""
|
||||
if fullBackup.PgFullBackupWalStartSegmentName != nil {
|
||||
startSegment = *fullBackup.PgFullBackupWalStartSegmentName
|
||||
}
|
||||
|
||||
lastWal, err := s.backupRepository.FindLastWalSegmentAfter(databaseID, stopSegment)
|
||||
walSegments, err := s.backupRepository.FindCompletedWalSegmentsAfter(database.ID, startSegment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query last WAL segment: %w", err)
|
||||
return nil, fmt.Errorf("failed to query WAL segments: %w", err)
|
||||
}
|
||||
|
||||
walCalculator := util_wal.NewWalCalculator(walSegmentSizeBytes)
|
||||
|
||||
var chainTail string
|
||||
if lastWal != nil && lastWal.PgWalSegmentName != nil {
|
||||
chainTail = *lastWal.PgWalSegmentName
|
||||
} else {
|
||||
chainTail = stopSegment
|
||||
}
|
||||
|
||||
expectedNext, err := walCalculator.NextSegment(chainTail)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("WAL arithmetic failed for %q: %w", chainTail, err)
|
||||
}
|
||||
|
||||
if incomingSegment != expectedNext {
|
||||
return &backups_dto.UploadGapResponse{
|
||||
Error: "gap_detected",
|
||||
ExpectedSegmentName: expectedNext,
|
||||
ReceivedSegmentName: incomingSegment,
|
||||
chainErr := s.validateRestoreWalChain(fullBackup, walSegments)
|
||||
if chainErr != nil {
|
||||
return &backups_dto.IsWalChainValidResponse{
|
||||
IsValid: false,
|
||||
Error: chainErr.Error,
|
||||
LastContiguousSegment: chainErr.LastContiguousSegment,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
return &backups_dto.IsWalChainValidResponse{
|
||||
IsValid: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *PostgreWalBackupService) createBackupRecord(
|
||||
func (s *PostgreWalBackupService) createBasebackupRecord(
|
||||
databaseID uuid.UUID,
|
||||
storageID uuid.UUID,
|
||||
uploadType backups_core.PgWalUploadType,
|
||||
dbName string,
|
||||
walSegmentName string,
|
||||
fullBackupWalStartSegment string,
|
||||
fullBackupWalStopSegment string,
|
||||
encryption backups_config.BackupEncryption,
|
||||
) *backups_core.Backup {
|
||||
now := time.Now().UTC()
|
||||
walBackupType := backups_core.PgWalBackupTypeFullBackup
|
||||
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: databaseID,
|
||||
StorageID: storageID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
Encryption: encryption,
|
||||
CreatedAt: now,
|
||||
ID: uuid.New(),
|
||||
DatabaseID: databaseID,
|
||||
StorageID: storageID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
PgWalBackupType: &walBackupType,
|
||||
Encryption: encryption,
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
backup.GenerateFilename(dbName)
|
||||
|
||||
if uploadType == backups_core.PgWalUploadTypeBasebackup {
|
||||
walBackupType := backups_core.PgWalBackupTypeFullBackup
|
||||
backup.PgWalBackupType = &walBackupType
|
||||
return backup
|
||||
}
|
||||
|
||||
if fullBackupWalStartSegment != "" {
|
||||
backup.PgFullBackupWalStartSegmentName = &fullBackupWalStartSegment
|
||||
}
|
||||
func (s *PostgreWalBackupService) createWalSegmentRecord(
|
||||
databaseID uuid.UUID,
|
||||
storageID uuid.UUID,
|
||||
dbName string,
|
||||
walSegmentName string,
|
||||
encryption backups_config.BackupEncryption,
|
||||
) *backups_core.Backup {
|
||||
now := time.Now().UTC()
|
||||
walBackupType := backups_core.PgWalBackupTypeWalSegment
|
||||
|
||||
if fullBackupWalStopSegment != "" {
|
||||
backup.PgFullBackupWalStopSegmentName = &fullBackupWalStopSegment
|
||||
}
|
||||
} else {
|
||||
walBackupType := backups_core.PgWalBackupTypeWalSegment
|
||||
backup.PgWalBackupType = &walBackupType
|
||||
backup.PgWalSegmentName = &walSegmentName
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: databaseID,
|
||||
StorageID: storageID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
PgWalBackupType: &walBackupType,
|
||||
PgWalSegmentName: &walSegmentName,
|
||||
Encryption: encryption,
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
backup.GenerateFilename(dbName)
|
||||
|
||||
return backup
|
||||
}
|
||||
|
||||
|
||||
@@ -2,11 +2,13 @@ package local_storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
@@ -108,7 +110,7 @@ func (l *LocalStorage) SaveFile(
|
||||
}
|
||||
|
||||
// Move the file from temp to backups directory
|
||||
if err = os.Rename(tempFilePath, finalPath); err != nil {
|
||||
if err = moveFile(tempFilePath, finalPath); err != nil {
|
||||
logger.Error(
|
||||
"Failed to move file from temp to backups",
|
||||
"fileName",
|
||||
@@ -197,6 +199,52 @@ func (l *LocalStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor)
|
||||
// Update receives another LocalStorage value but, as written here, has an
// empty body: nothing is copied from incoming onto the receiver.
func (l *LocalStorage) Update(incoming *LocalStorage) {
|
||||
}
|
||||
|
||||
// moveFile moves the file at src to dst.
//
// It first attempts os.Rename, which is the cheap path when both locations
// are on the same filesystem. If the rename fails with a cross-device link
// error (EXDEV), it falls back to copy-then-delete. This happens when users
// mount the temp and backups directories as separate Docker volumes (e.g. on
// Unraid with split volume mapping), so they reside on different filesystems.
//
// Any other rename error is returned unchanged. Errors from the fallback are
// wrapped with a message naming the step that failed. If the fallback fails
// after dst was created, the partial dst is removed so that a truncated file
// never remains at the final path; src is left untouched in that case.
func moveFile(src, dst string) error {
	err := os.Rename(src, dst)
	if err == nil {
		return nil
	}

	// Only a cross-device rename is recoverable by copying.
	var linkErr *os.LinkError
	if !errors.As(err, &linkErr) || !errors.Is(linkErr.Err, syscall.EXDEV) {
		return err
	}

	if err = copyFileAcrossDevices(src, dst); err != nil {
		return err
	}

	// Both handles are closed by now, so the removal also works on platforms
	// that refuse to delete open files.
	if err = os.Remove(src); err != nil {
		return fmt.Errorf("failed to remove source file: %w", err)
	}

	return nil
}

// copyFileAcrossDevices copies src into a newly created dst, syncs it to
// stable storage and closes it. On any failure after dst has been created,
// dst is closed and removed before the error is returned.
func copyFileAcrossDevices(src, dst string) (err error) {
	srcFile, err := os.Open(src)
	if err != nil {
		return fmt.Errorf("failed to open source file: %w", err)
	}
	defer func() {
		// Read-only handle: a close error cannot lose data.
		_ = srcFile.Close()
	}()

	dstFile, err := os.Create(dst)
	if err != nil {
		return fmt.Errorf("failed to create destination file: %w", err)
	}
	defer func() {
		if err != nil {
			// Best-effort cleanup; the original failure is what gets reported.
			_ = dstFile.Close()
			_ = os.Remove(dst)
		}
	}()

	if _, err = io.Copy(dstFile, srcFile); err != nil {
		return fmt.Errorf("failed to copy file: %w", err)
	}

	if err = dstFile.Sync(); err != nil {
		return fmt.Errorf("failed to sync destination file: %w", err)
	}

	// Close explicitly so a late write-back error is not silently dropped.
	if err = dstFile.Close(); err != nil {
		return fmt.Errorf("failed to close destination file: %w", err)
	}

	return nil
}
|
||||
|
||||
func copyWithContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
|
||||
buf := make([]byte, localChunkSize)
|
||||
var written int64
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
-- +goose Up
|
||||
-- Add the upload_completed_at column to backups. No NOT NULL or DEFAULT is
-- specified, so the column is nullable.
ALTER TABLE backups
|
||||
ADD COLUMN upload_completed_at TIMESTAMPTZ;
|
||||
|
||||
-- +goose Down
|
||||
-- Revert the Up step by dropping the same column.
ALTER TABLE backups
|
||||
DROP COLUMN upload_completed_at;
|
||||
@@ -5,6 +5,10 @@ metadata:
|
||||
namespace: {{ include "databasus.namespace" . }}
|
||||
labels:
|
||||
{{- include "databasus.labels" . | nindent 4 }}
|
||||
{{- with .Values.service.annotations }}
|
||||
annotations:
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
type: {{ .Values.service.type }}
|
||||
ports:
|
||||
|
||||
@@ -17,10 +17,10 @@ service:
|
||||
type: ClusterIP
|
||||
port: 4005 # Service port
|
||||
targetPort: 4005 # Internal container port
|
||||
annotations: {}
|
||||
# Headless service for StatefulSet
|
||||
headless:
|
||||
enabled: true
|
||||
|
||||
# Resource limits and requests
|
||||
resources:
|
||||
requests:
|
||||
|
||||
Reference in New Issue
Block a user