mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
Compare commits
44 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f0064b4be3 | ||
|
|
94505bab3f | ||
|
|
9acf3cff09 | ||
|
|
0d7e147df6 | ||
|
|
1394b47570 | ||
|
|
a9865ae3e4 | ||
|
|
4b5478e60a | ||
|
|
6355301903 | ||
|
|
29b403a9c6 | ||
|
|
12606053f4 | ||
|
|
904b386378 | ||
|
|
1d9738b808 | ||
|
|
58b37f4c92 | ||
|
|
6c4f814c94 | ||
|
|
bcd13c27d3 | ||
|
|
120f9600bf | ||
|
|
563c7c1d64 | ||
|
|
68f15f7661 | ||
|
|
627d96a00d | ||
|
|
02b9a9ec8d | ||
|
|
415dda8752 | ||
|
|
3faf85796a | ||
|
|
edd2759f5a | ||
|
|
c283856f38 | ||
|
|
6059e1a33b | ||
|
|
2deda2e7ea | ||
|
|
acf1143752 | ||
|
|
889063a8b4 | ||
|
|
a1e20e7b10 | ||
|
|
7e76945550 | ||
|
|
d98acfc4af | ||
|
|
0ffc7c8c96 | ||
|
|
1b011bdcd4 | ||
|
|
7e209ff537 | ||
|
|
f712e3a437 | ||
|
|
bcd7d8e1aa | ||
|
|
880a7488e9 | ||
|
|
ca4d483f2c | ||
|
|
1b511410a6 | ||
|
|
c8edff8046 | ||
|
|
f60e3d956b | ||
|
|
f2cb9022f2 | ||
|
|
4b3f36eea2 | ||
|
|
460063e7a5 |
87
.github/workflows/ci-release.yml
vendored
87
.github/workflows/ci-release.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
||||
lint-backend:
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: golang:1.24.9
|
||||
image: golang:1.26.1
|
||||
volumes:
|
||||
- /runner-cache/go-pkg:/go/pkg/mod
|
||||
- /runner-cache/go-build:/root/.cache/go-build
|
||||
@@ -32,7 +32,7 @@ jobs:
|
||||
|
||||
- name: Install golangci-lint
|
||||
run: |
|
||||
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.7.2
|
||||
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.11.3
|
||||
echo "$(go env GOPATH)/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Install swag for swagger generation
|
||||
@@ -86,6 +86,39 @@ jobs:
|
||||
cd frontend
|
||||
npm run build
|
||||
|
||||
lint-agent:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.26.1"
|
||||
cache-dependency-path: agent/go.sum
|
||||
|
||||
- name: Download Go modules
|
||||
run: |
|
||||
cd agent
|
||||
go mod download
|
||||
|
||||
- name: Install golangci-lint
|
||||
run: |
|
||||
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.11.3
|
||||
echo "$(go env GOPATH)/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Run golangci-lint
|
||||
run: |
|
||||
cd agent
|
||||
golangci-lint run
|
||||
|
||||
- name: Verify go mod tidy
|
||||
run: |
|
||||
cd agent
|
||||
go mod tidy
|
||||
git diff --exit-code go.mod go.sum || (echo "go mod tidy made changes, please run 'go mod tidy' and commit the changes" && exit 1)
|
||||
|
||||
test-frontend:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint-frontend]
|
||||
@@ -108,11 +141,55 @@ jobs:
|
||||
cd frontend
|
||||
npm run test
|
||||
|
||||
test-agent:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint-agent]
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.26.1"
|
||||
cache-dependency-path: agent/go.sum
|
||||
|
||||
- name: Download Go modules
|
||||
run: |
|
||||
cd agent
|
||||
go mod download
|
||||
|
||||
- name: Run Go tests
|
||||
run: |
|
||||
cd agent
|
||||
go test -count=1 -failfast ./internal/...
|
||||
|
||||
e2e-agent:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint-agent]
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Run e2e tests
|
||||
run: |
|
||||
cd agent
|
||||
make e2e
|
||||
|
||||
- name: Cleanup
|
||||
if: always()
|
||||
run: |
|
||||
cd agent/e2e
|
||||
docker compose down -v --rmi local || true
|
||||
rm -rf artifacts || true
|
||||
|
||||
# Self-hosted: performant high-frequency CPU is used to start many containers and run tests fast. Tests
|
||||
# step is bottle-neck, because we need a lot of containers and cannot parallelize tests due to shared resources
|
||||
test-backend:
|
||||
runs-on: self-hosted
|
||||
needs: [lint-backend]
|
||||
container:
|
||||
image: golang:1.24.9
|
||||
image: golang:1.26.1
|
||||
options: --privileged -v /var/run/docker.sock:/var/run/docker.sock --add-host=host.docker.internal:host-gateway
|
||||
volumes:
|
||||
- /runner-cache/go-pkg:/go/pkg/mod
|
||||
@@ -441,7 +518,7 @@ jobs:
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: node:20
|
||||
needs: [test-backend, test-frontend]
|
||||
needs: [test-backend, test-frontend, test-agent, e2e-agent]
|
||||
if: ${{ github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, '[skip-release]') }}
|
||||
outputs:
|
||||
should_release: ${{ steps.version_bump.outputs.should_release }}
|
||||
@@ -534,7 +611,7 @@ jobs:
|
||||
|
||||
build-only:
|
||||
runs-on: self-hosted
|
||||
needs: [test-backend, test-frontend]
|
||||
needs: [test-backend, test-frontend, test-agent, e2e-agent]
|
||||
if: ${{ github.ref == 'refs/heads/main' && contains(github.event.head_commit.message, '[skip-release]') }}
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -5,6 +5,7 @@ databasus-data/
|
||||
.env
|
||||
pgdata/
|
||||
docker-compose.yml
|
||||
!agent/e2e/docker-compose.yml
|
||||
node_modules/
|
||||
.idea
|
||||
/articles
|
||||
|
||||
@@ -41,3 +41,20 @@ repos:
|
||||
language: system
|
||||
files: ^backend/.*\.go$
|
||||
pass_filenames: false
|
||||
|
||||
# Agent checks
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: agent-format-and-lint
|
||||
name: Agent Format & Lint (golangci-lint)
|
||||
entry: bash -c "cd agent && golangci-lint fmt ./internal/... ./cmd/... && golangci-lint run ./internal/... ./cmd/..."
|
||||
language: system
|
||||
files: ^agent/.*\.go$
|
||||
pass_filenames: false
|
||||
|
||||
- id: agent-go-mod-tidy
|
||||
name: Agent Go Mod Tidy
|
||||
entry: bash -c "cd agent && go mod tidy"
|
||||
language: system
|
||||
files: ^agent/.*\.go$
|
||||
pass_filenames: false
|
||||
|
||||
34
AGENTS.md
34
AGENTS.md
@@ -73,6 +73,10 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
|
||||
- Patch the answer accordingly
|
||||
- Verify edge cases are handled
|
||||
|
||||
7. **Fix the reason, not the symptom:**
|
||||
- If you find a bug or issue, ask "Why did this happen?" and fix the root cause
|
||||
- Avoid quick fixes that don't address underlying problems
|
||||
|
||||
### Application guidelines:
|
||||
|
||||
**Scale your response to the task:**
|
||||
@@ -88,6 +92,36 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
|
||||
|
||||
## Backend guidelines
|
||||
|
||||
### Naming
|
||||
|
||||
Variables and functions naming are the most important part of code readability. Always choose descriptive and meaningful names that clearly indicate the purpose and intent of the code.
|
||||
|
||||
Avoid abbreviations, unless they are widely accepted and unambiguous (e.g., `ID`, `URL`, `HTTP`). Use consistent naming conventions across the codebase.
|
||||
|
||||
Do not use one-two letters. For example:
|
||||
|
||||
Bad:
|
||||
|
||||
```
|
||||
u := users.getUser()
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
r := bufio.NewReader(pr)
|
||||
```
|
||||
|
||||
Good:
|
||||
|
||||
```
|
||||
user := users.GetUser()
|
||||
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
|
||||
bufferedReader := bufio.NewReader(pipeReader)
|
||||
```
|
||||
|
||||
Exclusion: widely used variables like "db", "ctx", "req", "res", etc.
|
||||
|
||||
### Code style
|
||||
|
||||
**Always place private methods to the bottom of file**
|
||||
|
||||
45
Dockerfile
45
Dockerfile
@@ -22,7 +22,7 @@ RUN npm run build
|
||||
|
||||
# ========= BUILD BACKEND =========
|
||||
# Backend build stage
|
||||
FROM --platform=$BUILDPLATFORM golang:1.24.9 AS backend-build
|
||||
FROM --platform=$BUILDPLATFORM golang:1.26.1 AS backend-build
|
||||
|
||||
# Make TARGET args available early so tools built here match the final image arch
|
||||
ARG TARGETOS
|
||||
@@ -66,6 +66,43 @@ RUN CGO_ENABLED=0 \
|
||||
go build -o /app/main ./cmd/main.go
|
||||
|
||||
|
||||
# ========= BUILD AGENT =========
|
||||
# Builds the databasus-agent CLI binary for BOTH x86_64 and ARM64.
|
||||
# Both architectures are always built because:
|
||||
# - Databasus server runs on one arch (e.g. amd64)
|
||||
# - The agent runs on remote PostgreSQL servers that may be on a
|
||||
# different arch (e.g. arm64)
|
||||
# - The backend serves the correct binary based on the agent's
|
||||
# ?arch= query parameter
|
||||
#
|
||||
# We cross-compile from the build platform (no QEMU needed) because the
|
||||
# agent is pure Go with zero C dependencies.
|
||||
# CGO_ENABLED=0 produces fully static binaries — no glibc/musl dependency,
|
||||
# so the agent runs on any Linux distro (Alpine, Debian, Ubuntu, RHEL, etc.).
|
||||
# APP_VERSION is baked into the binary via -ldflags so the agent can
|
||||
# compare its version against the server and auto-update when needed.
|
||||
FROM --platform=$BUILDPLATFORM golang:1.26.1 AS agent-build
|
||||
|
||||
ARG APP_VERSION=dev
|
||||
|
||||
WORKDIR /agent
|
||||
|
||||
COPY agent/go.mod ./
|
||||
RUN go mod download
|
||||
|
||||
COPY agent/ ./
|
||||
|
||||
# Build for x86_64 (amd64) — static binary, no glibc dependency
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 \
|
||||
go build -ldflags "-X main.Version=${APP_VERSION}" \
|
||||
-o /agent-binaries/databasus-agent-linux-amd64 ./cmd/main.go
|
||||
|
||||
# Build for ARM64 (arm64) — static binary, no glibc dependency
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=arm64 \
|
||||
go build -ldflags "-X main.Version=${APP_VERSION}" \
|
||||
-o /agent-binaries/databasus-agent-linux-arm64 ./cmd/main.go
|
||||
|
||||
|
||||
# ========= RUNTIME =========
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
@@ -220,6 +257,10 @@ COPY backend/migrations ./migrations
|
||||
# Copy UI files
|
||||
COPY --from=backend-build /app/ui/build ./ui/build
|
||||
|
||||
# Copy agent binaries (both architectures) — served by the backend
|
||||
# at GET /api/v1/system/agent?arch=amd64|arm64
|
||||
COPY --from=agent-build /agent-binaries ./agent-binaries
|
||||
|
||||
# Copy .env file (with fallback to .env.production.example)
|
||||
COPY backend/.env* /app/
|
||||
RUN if [ ! -f /app/.env ]; then \
|
||||
@@ -397,6 +438,8 @@ fi
|
||||
# Create database and set password for postgres user
|
||||
echo "Setting up database and user..."
|
||||
gosu postgres \$PG_BIN/psql -p 5437 -h localhost -d postgres << 'SQL'
|
||||
|
||||
# We use stub password, because internal DB is not exposed outside container
|
||||
ALTER USER postgres WITH PASSWORD 'Q1234567';
|
||||
SELECT 'CREATE DATABASE databasus OWNER postgres'
|
||||
WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = 'databasus')
|
||||
|
||||
@@ -261,12 +261,17 @@ Also you can join our large community of developers, DBAs and DevOps engineers o
|
||||
|
||||
There have been questions about AI usage in project development in issues and discussions. As the project focuses on security, reliability and production usage, it's important to explain how AI is used in the development process.
|
||||
|
||||
First of all, we are proud to say that Databasus has been accepted into both [Claude for Open Source](https://claude.com/contact-sales/claude-for-oss) by Anthropic and [Codex for Open Source](https://developers.openai.com/codex/community/codex-for-oss/) by OpenAI in March 2026. For us it is one more signal that the project was recognized as important open-source software and was as critical infrastructure worth supporting independently by two of the world's leading AI companies. Read more at [databasus.com/faq](https://databasus.com/faq#oss-programs).
|
||||
|
||||
Despite of this, we have the following rules how AI is used in the development process:
|
||||
|
||||
AI is used as a helper for:
|
||||
|
||||
- verification of code quality and searching for vulnerabilities
|
||||
- cleaning up and improving documentation, comments and code
|
||||
- assistance during development
|
||||
- double-checking PRs and commits after human review
|
||||
- additional security analysis of PRs via Codex Security
|
||||
|
||||
AI is not used for:
|
||||
|
||||
|
||||
1
agent/.env.example
Normal file
1
agent/.env.example
Normal file
@@ -0,0 +1 @@
|
||||
ENV_MODE=development
|
||||
26
agent/.gitignore
vendored
Normal file
26
agent/.gitignore
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
main
|
||||
.env
|
||||
docker-compose.yml
|
||||
!e2e/docker-compose.yml
|
||||
pgdata
|
||||
pgdata_test/
|
||||
mysqldata/
|
||||
mariadbdata/
|
||||
main.exe
|
||||
swagger/
|
||||
swagger/*
|
||||
swagger/docs.go
|
||||
swagger/swagger.json
|
||||
swagger/swagger.yaml
|
||||
postgresus-backend.exe
|
||||
databasus-backend.exe
|
||||
ui/build/*
|
||||
pgdata-for-restore/
|
||||
temp/
|
||||
cmd.exe
|
||||
temp/
|
||||
valkey-data/
|
||||
victoria-logs-data/
|
||||
databasus.json
|
||||
.test-tmp/
|
||||
databasus.log
|
||||
41
agent/.golangci.yml
Normal file
41
agent/.golangci.yml
Normal file
@@ -0,0 +1,41 @@
|
||||
version: "2"
|
||||
|
||||
run:
|
||||
timeout: 5m
|
||||
tests: false
|
||||
concurrency: 4
|
||||
|
||||
linters:
|
||||
default: standard
|
||||
enable:
|
||||
- funcorder
|
||||
- bodyclose
|
||||
- errorlint
|
||||
- gocritic
|
||||
- unconvert
|
||||
- misspell
|
||||
- errname
|
||||
- noctx
|
||||
- modernize
|
||||
|
||||
settings:
|
||||
errcheck:
|
||||
check-type-assertions: true
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
- gofumpt
|
||||
- golines
|
||||
- gci
|
||||
|
||||
settings:
|
||||
golines:
|
||||
max-len: 120
|
||||
gofumpt:
|
||||
module-path: databasus-agent
|
||||
extra-rules: true
|
||||
gci:
|
||||
sections:
|
||||
- standard
|
||||
- default
|
||||
- localmodule
|
||||
26
agent/Makefile
Normal file
26
agent/Makefile
Normal file
@@ -0,0 +1,26 @@
|
||||
.PHONY: run build test lint e2e e2e-clean
|
||||
|
||||
# Usage: make run ARGS="start --pg-host localhost"
|
||||
run:
|
||||
go run cmd/main.go $(ARGS)
|
||||
|
||||
build:
|
||||
CGO_ENABLED=0 go build -ldflags "-X main.Version=$(VERSION)" -o databasus-agent ./cmd/main.go
|
||||
|
||||
test:
|
||||
go test -count=1 -failfast ./internal/...
|
||||
|
||||
lint:
|
||||
golangci-lint fmt ./cmd/... ./internal/... ./e2e/... && golangci-lint run ./cmd/... ./internal/... ./e2e/...
|
||||
|
||||
e2e:
|
||||
cd e2e && docker compose build
|
||||
cd e2e && docker compose run --rm e2e-agent-builder
|
||||
cd e2e && docker compose up -d e2e-postgres e2e-mock-server
|
||||
cd e2e && docker compose run --rm e2e-agent-runner
|
||||
cd e2e && docker compose run --rm e2e-agent-docker
|
||||
cd e2e && docker compose down -v
|
||||
|
||||
e2e-clean:
|
||||
cd e2e && docker compose down -v --rmi local
|
||||
rm -rf e2e/artifacts
|
||||
227
agent/cmd/main.go
Normal file
227
agent/cmd/main.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
"databasus-agent/internal/features/start"
|
||||
"databasus-agent/internal/features/upgrade"
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
var Version = "dev"
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 2 {
|
||||
printUsage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
switch os.Args[1] {
|
||||
case "start":
|
||||
runStart(os.Args[2:])
|
||||
case "_run":
|
||||
runDaemon(os.Args[2:])
|
||||
case "stop":
|
||||
runStop()
|
||||
case "status":
|
||||
runStatus()
|
||||
case "restore":
|
||||
runRestore(os.Args[2:])
|
||||
case "version":
|
||||
fmt.Println(Version)
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "unknown command: %s\n", os.Args[1])
|
||||
printUsage()
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runStart(args []string) {
|
||||
fs := flag.NewFlagSet("start", flag.ExitOnError)
|
||||
|
||||
isSkipUpdate := fs.Bool("skip-update", false, "Skip auto-update check")
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.LoadFromJSONAndArgs(fs, args)
|
||||
|
||||
if err := cfg.SaveToJSON(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to save config: %v\n", err)
|
||||
}
|
||||
|
||||
log := logger.GetLogger()
|
||||
|
||||
isDev := checkIsDevelopment()
|
||||
runUpdateCheck(cfg.DatabasusHost, *isSkipUpdate, isDev, log)
|
||||
|
||||
if err := start.Start(cfg, Version, isDev, log); err != nil {
|
||||
if errors.Is(err, upgrade.ErrUpgradeRestart) {
|
||||
reexecAfterUpgrade(log)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runDaemon(args []string) {
|
||||
fs := flag.NewFlagSet("_run", flag.ExitOnError)
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
log := logger.GetLogger()
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.LoadFromJSON()
|
||||
|
||||
if err := start.RunDaemon(cfg, Version, checkIsDevelopment(), log); err != nil {
|
||||
if errors.Is(err, upgrade.ErrUpgradeRestart) {
|
||||
reexecAfterUpgrade(log)
|
||||
}
|
||||
|
||||
log.Error("Agent exited with error", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runStop() {
|
||||
log := logger.GetLogger()
|
||||
|
||||
if err := start.Stop(log); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runStatus() {
|
||||
log := logger.GetLogger()
|
||||
|
||||
if err := start.Status(log); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runRestore(args []string) {
|
||||
fs := flag.NewFlagSet("restore", flag.ExitOnError)
|
||||
|
||||
targetDir := fs.String("target-dir", "", "Target pgdata directory")
|
||||
backupID := fs.String("backup-id", "", "Full backup UUID (optional)")
|
||||
targetTime := fs.String("target-time", "", "PITR target time in RFC3339 (optional)")
|
||||
isYes := fs.Bool("yes", false, "Skip confirmation prompt")
|
||||
isSkipUpdate := fs.Bool("skip-update", false, "Skip auto-update check")
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.LoadFromJSONAndArgs(fs, args)
|
||||
|
||||
if err := cfg.SaveToJSON(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to save config: %v\n", err)
|
||||
}
|
||||
|
||||
log := logger.GetLogger()
|
||||
|
||||
isDev := checkIsDevelopment()
|
||||
runUpdateCheck(cfg.DatabasusHost, *isSkipUpdate, isDev, log)
|
||||
|
||||
log.Info("restore: stub — not yet implemented",
|
||||
"targetDir", *targetDir,
|
||||
"backupId", *backupID,
|
||||
"targetTime", *targetTime,
|
||||
"yes", *isYes,
|
||||
)
|
||||
}
|
||||
|
||||
func printUsage() {
|
||||
fmt.Fprintln(os.Stderr, "Usage: databasus-agent <command> [flags]")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "Commands:")
|
||||
fmt.Fprintln(os.Stderr, " start Start the agent (WAL archiving + basebackups)")
|
||||
fmt.Fprintln(os.Stderr, " stop Stop a running agent")
|
||||
fmt.Fprintln(os.Stderr, " status Show agent status")
|
||||
fmt.Fprintln(os.Stderr, " restore Restore a database from backup")
|
||||
fmt.Fprintln(os.Stderr, " version Print agent version")
|
||||
}
|
||||
|
||||
func runUpdateCheck(host string, isSkipUpdate, isDev bool, log *slog.Logger) {
|
||||
if isSkipUpdate {
|
||||
return
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
return
|
||||
}
|
||||
|
||||
apiClient := api.NewClient(host, "", log)
|
||||
|
||||
isUpgraded, err := upgrade.CheckAndUpdate(apiClient, Version, isDev, log)
|
||||
if err != nil {
|
||||
log.Error("Auto-update failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if isUpgraded {
|
||||
reexecAfterUpgrade(log)
|
||||
}
|
||||
}
|
||||
|
||||
func checkIsDevelopment() bool {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for range 3 {
|
||||
if data, err := os.ReadFile(filepath.Join(dir, ".env")); err == nil {
|
||||
return parseEnvMode(data)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
dir = filepath.Dir(dir)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func parseEnvMode(data []byte) bool {
|
||||
for line := range strings.SplitSeq(string(data), "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) == 2 && strings.TrimSpace(parts[0]) == "ENV_MODE" {
|
||||
return strings.TrimSpace(parts[1]) == "development"
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func reexecAfterUpgrade(log *slog.Logger) {
|
||||
selfPath, err := os.Executable()
|
||||
if err != nil {
|
||||
log.Error("Failed to resolve executable for re-exec", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
log.Info("Re-executing after upgrade...")
|
||||
|
||||
if err := syscall.Exec(selfPath, os.Args, os.Environ()); err != nil {
|
||||
log.Error("Failed to re-exec after upgrade", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
1
agent/e2e/.gitignore
vendored
Normal file
1
agent/e2e/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
artifacts/
|
||||
13
agent/e2e/Dockerfile.agent-builder
Normal file
13
agent/e2e/Dockerfile.agent-builder
Normal file
@@ -0,0 +1,13 @@
|
||||
# Builds agent binaries with different versions so
|
||||
# we can test upgrade behavior (v1 -> v2)
|
||||
FROM golang:1.26.1-alpine AS build
|
||||
WORKDIR /src
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
RUN CGO_ENABLED=0 go build -ldflags "-X main.Version=v1.0.0" -o /out/agent-v1 ./cmd/main.go
|
||||
RUN CGO_ENABLED=0 go build -ldflags "-X main.Version=v2.0.0" -o /out/agent-v2 ./cmd/main.go
|
||||
|
||||
FROM alpine:3.21
|
||||
COPY --from=build /out/ /out/
|
||||
CMD ["cp", "-v", "/out/agent-v1", "/out/agent-v2", "/artifacts/"]
|
||||
8
agent/e2e/Dockerfile.agent-docker
Normal file
8
agent/e2e/Dockerfile.agent-docker
Normal file
@@ -0,0 +1,8 @@
|
||||
# Runs pg_basebackup-via-docker-exec test (test 5) which tests
|
||||
# that the agent can connect to Postgres inside Docker container
|
||||
FROM docker:27-cli
|
||||
|
||||
RUN apk add --no-cache bash curl
|
||||
|
||||
WORKDIR /tmp
|
||||
ENTRYPOINT []
|
||||
14
agent/e2e/Dockerfile.agent-runner
Normal file
14
agent/e2e/Dockerfile.agent-runner
Normal file
@@ -0,0 +1,14 @@
|
||||
# Runs upgrade and host-mode pg_basebackup tests (tests 1-4). Needs
|
||||
# Postgres client tools to be installed inside the system
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ca-certificates curl gnupg2 postgresql-common && \
|
||||
/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
postgresql-client-17 && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /tmp
|
||||
ENTRYPOINT []
|
||||
10
agent/e2e/Dockerfile.mock-server
Normal file
10
agent/e2e/Dockerfile.mock-server
Normal file
@@ -0,0 +1,10 @@
|
||||
# Mock databasus API server for version checks and binary downloads. Just
|
||||
# serves static responses and files from the `artifacts` directory.
|
||||
FROM golang:1.26.1-alpine AS build
|
||||
WORKDIR /app
|
||||
COPY mock-server/main.go .
|
||||
RUN CGO_ENABLED=0 go build -o mock-server main.go
|
||||
|
||||
FROM alpine:3.21
|
||||
COPY --from=build /app/mock-server /usr/local/bin/mock-server
|
||||
ENTRYPOINT ["mock-server"]
|
||||
64
agent/e2e/docker-compose.yml
Normal file
64
agent/e2e/docker-compose.yml
Normal file
@@ -0,0 +1,64 @@
|
||||
services:
|
||||
e2e-agent-builder:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: e2e/Dockerfile.agent-builder
|
||||
volumes:
|
||||
- ./artifacts:/artifacts
|
||||
container_name: e2e-agent-builder
|
||||
|
||||
e2e-postgres:
|
||||
image: postgres:17
|
||||
environment:
|
||||
POSTGRES_DB: testdb
|
||||
POSTGRES_USER: testuser
|
||||
POSTGRES_PASSWORD: testpassword
|
||||
container_name: e2e-agent-postgres
|
||||
command: postgres -c wal_level=replica -c max_wal_senders=3
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U testuser -d testdb"]
|
||||
interval: 2s
|
||||
timeout: 5s
|
||||
retries: 30
|
||||
|
||||
e2e-mock-server:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.mock-server
|
||||
volumes:
|
||||
- ./artifacts:/artifacts:ro
|
||||
container_name: e2e-mock-server
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "-q", "--spider", "http://localhost:4050/health"]
|
||||
interval: 2s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
e2e-agent-runner:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.agent-runner
|
||||
volumes:
|
||||
- ./artifacts:/opt/agent/artifacts:ro
|
||||
- ./scripts:/opt/agent/scripts:ro
|
||||
depends_on:
|
||||
e2e-postgres:
|
||||
condition: service_healthy
|
||||
e2e-mock-server:
|
||||
condition: service_healthy
|
||||
container_name: e2e-agent-runner
|
||||
command: ["bash", "/opt/agent/scripts/run-all.sh", "host"]
|
||||
|
||||
e2e-agent-docker:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.agent-docker
|
||||
volumes:
|
||||
- ./artifacts:/opt/agent/artifacts:ro
|
||||
- ./scripts:/opt/agent/scripts:ro
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
depends_on:
|
||||
e2e-postgres:
|
||||
condition: service_healthy
|
||||
container_name: e2e-agent-docker
|
||||
command: ["bash", "/opt/agent/scripts/run-all.sh", "docker"]
|
||||
108
agent/e2e/mock-server/main.go
Normal file
108
agent/e2e/mock-server/main.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type server struct {
|
||||
mu sync.RWMutex
|
||||
version string
|
||||
binaryPath string
|
||||
}
|
||||
|
||||
func main() {
|
||||
version := "v2.0.0"
|
||||
binaryPath := "/artifacts/agent-v2"
|
||||
port := "4050"
|
||||
|
||||
s := &server{version: version, binaryPath: binaryPath}
|
||||
|
||||
http.HandleFunc("/api/v1/system/version", s.handleVersion)
|
||||
http.HandleFunc("/api/v1/system/agent", s.handleAgentDownload)
|
||||
http.HandleFunc("/mock/set-version", s.handleSetVersion)
|
||||
http.HandleFunc("/mock/set-binary-path", s.handleSetBinaryPath)
|
||||
http.HandleFunc("/health", s.handleHealth)
|
||||
|
||||
addr := ":" + port
|
||||
log.Printf("Mock server starting on %s (version=%s, binary=%s)", addr, version, binaryPath)
|
||||
|
||||
if err := http.ListenAndServe(addr, nil); err != nil {
|
||||
log.Fatalf("Server failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) handleVersion(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.RLock()
|
||||
v := s.version
|
||||
s.mu.RUnlock()
|
||||
|
||||
log.Printf("GET /api/v1/system/version -> %s", v)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"version": v})
|
||||
}
|
||||
|
||||
func (s *server) handleAgentDownload(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.RLock()
|
||||
path := s.binaryPath
|
||||
s.mu.RUnlock()
|
||||
|
||||
log.Printf("GET /api/v1/system/agent (arch=%s) -> serving %s", r.URL.Query().Get("arch"), path)
|
||||
|
||||
http.ServeFile(w, r, path)
|
||||
}
|
||||
|
||||
func (s *server) handleSetVersion(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.version = body.Version
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Printf("POST /mock/set-version -> %s", body.Version)
|
||||
|
||||
_, _ = fmt.Fprintf(w, "version set to %s", body.Version)
|
||||
}
|
||||
|
||||
func (s *server) handleSetBinaryPath(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var body struct {
|
||||
BinaryPath string `json:"binaryPath"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.binaryPath = body.BinaryPath
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Printf("POST /mock/set-binary-path -> %s", body.BinaryPath)
|
||||
|
||||
_, _ = fmt.Fprintf(w, "binary path set to %s", body.BinaryPath)
|
||||
}
|
||||
|
||||
func (s *server) handleHealth(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}
|
||||
49
agent/e2e/scripts/run-all.sh
Normal file
49
agent/e2e/scripts/run-all.sh
Normal file
@@ -0,0 +1,49 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
MODE="${1:-host}"
|
||||
SCRIPT_DIR="$(dirname "$0")"
|
||||
PASSED=0
|
||||
FAILED=0
|
||||
|
||||
run_test() {
|
||||
local name="$1"
|
||||
local script="$2"
|
||||
|
||||
echo ""
|
||||
echo "========================================"
|
||||
echo " $name"
|
||||
echo "========================================"
|
||||
|
||||
if bash "$script"; then
|
||||
echo " PASSED: $name"
|
||||
PASSED=$((PASSED + 1))
|
||||
else
|
||||
echo " FAILED: $name"
|
||||
FAILED=$((FAILED + 1))
|
||||
fi
|
||||
}
|
||||
|
||||
if [ "$MODE" = "host" ]; then
|
||||
run_test "Test 1: Upgrade success (v1 -> v2)" "$SCRIPT_DIR/test-upgrade-success.sh"
|
||||
run_test "Test 2: Upgrade skip (version matches)" "$SCRIPT_DIR/test-upgrade-skip.sh"
|
||||
run_test "Test 3: Background upgrade (v1 -> v2 while running)" "$SCRIPT_DIR/test-upgrade-background.sh"
|
||||
run_test "Test 4: pg_basebackup in PATH" "$SCRIPT_DIR/test-pg-host-path.sh"
|
||||
run_test "Test 5: pg_basebackup via bindir" "$SCRIPT_DIR/test-pg-host-bindir.sh"
|
||||
|
||||
elif [ "$MODE" = "docker" ]; then
|
||||
run_test "Test 6: pg_basebackup via docker exec" "$SCRIPT_DIR/test-pg-docker-exec.sh"
|
||||
|
||||
else
|
||||
echo "Unknown mode: $MODE (expected 'host' or 'docker')"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "========================================"
|
||||
echo " Results: $PASSED passed, $FAILED failed"
|
||||
echo "========================================"
|
||||
|
||||
if [ "$FAILED" -gt 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
61
agent/e2e/scripts/test-pg-docker-exec.sh
Normal file
61
agent/e2e/scripts/test-pg-docker-exec.sh
Normal file
@@ -0,0 +1,61 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
PG_CONTAINER="e2e-agent-postgres"
|
||||
|
||||
# Cleanup from previous runs
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
|
||||
# Copy agent binary
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Verify docker CLI works and PG container is accessible
|
||||
if ! docker exec "$PG_CONTAINER" pg_basebackup --version > /dev/null 2>&1; then
|
||||
echo "FAIL: Cannot reach pg_basebackup inside container $PG_CONTAINER (test setup issue)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run start with --skip-update and pg-type=docker
|
||||
echo "Running agent start (pg_basebackup via docker exec)..."
|
||||
OUTPUT=$("$AGENT" start \
|
||||
--skip-update \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--pg-wal-dir /tmp/wal \
|
||||
--pg-type docker \
|
||||
--pg-docker-container-name "$PG_CONTAINER" 2>&1)
|
||||
|
||||
EXIT_CODE=$?
|
||||
echo "$OUTPUT"
|
||||
|
||||
if [ "$EXIT_CODE" -ne 0 ]; then
|
||||
echo "FAIL: Agent exited with code $EXIT_CODE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "pg_basebackup verified (docker)"; then
|
||||
echo "FAIL: Expected output to contain 'pg_basebackup verified (docker)'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "PostgreSQL connection verified"; then
|
||||
echo "FAIL: Expected output to contain 'PostgreSQL connection verified'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "pg_basebackup found via docker exec and DB connection verified"
|
||||
67
agent/e2e/scripts/test-pg-host-bindir.sh
Normal file
67
agent/e2e/scripts/test-pg-host-bindir.sh
Normal file
@@ -0,0 +1,67 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
CUSTOM_BIN_DIR="/opt/pg/bin"
|
||||
|
||||
# Cleanup from previous runs
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
|
||||
# Copy agent binary
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Move pg_basebackup out of PATH into custom directory
|
||||
mkdir -p "$CUSTOM_BIN_DIR"
|
||||
cp "$(which pg_basebackup)" "$CUSTOM_BIN_DIR/pg_basebackup"
|
||||
|
||||
# Hide the system one by prepending an empty dir to PATH
|
||||
export PATH="/opt/empty-path:$PATH"
|
||||
mkdir -p /opt/empty-path
|
||||
|
||||
# Verify pg_basebackup is NOT directly callable from default location
|
||||
# (we copied it, but the original is still there in debian — so we test
|
||||
# that the agent uses the custom dir, not PATH, by checking the output)
|
||||
|
||||
# Run start with --skip-update and custom bin dir
|
||||
echo "Running agent start (pg_basebackup via --pg-host-bin-dir)..."
|
||||
OUTPUT=$("$AGENT" start \
|
||||
--skip-update \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--pg-wal-dir /tmp/wal \
|
||||
--pg-type host \
|
||||
--pg-host-bin-dir "$CUSTOM_BIN_DIR" 2>&1)
|
||||
|
||||
EXIT_CODE=$?
|
||||
echo "$OUTPUT"
|
||||
|
||||
if [ "$EXIT_CODE" -ne 0 ]; then
|
||||
echo "FAIL: Agent exited with code $EXIT_CODE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "pg_basebackup verified"; then
|
||||
echo "FAIL: Expected output to contain 'pg_basebackup verified'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "PostgreSQL connection verified"; then
|
||||
echo "FAIL: Expected output to contain 'PostgreSQL connection verified'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "pg_basebackup found via custom bin dir and DB connection verified"
|
||||
59
agent/e2e/scripts/test-pg-host-path.sh
Normal file
59
agent/e2e/scripts/test-pg-host-path.sh
Normal file
@@ -0,0 +1,59 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
|
||||
# Cleanup from previous runs
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
|
||||
# Copy agent binary
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Verify pg_basebackup is in PATH
|
||||
if ! which pg_basebackup > /dev/null 2>&1; then
|
||||
echo "FAIL: pg_basebackup not found in PATH (test setup issue)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run start with --skip-update and pg-type=host
|
||||
echo "Running agent start (pg_basebackup in PATH)..."
|
||||
OUTPUT=$("$AGENT" start \
|
||||
--skip-update \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--pg-wal-dir /tmp/wal \
|
||||
--pg-type host 2>&1)
|
||||
|
||||
EXIT_CODE=$?
|
||||
echo "$OUTPUT"
|
||||
|
||||
if [ "$EXIT_CODE" -ne 0 ]; then
|
||||
echo "FAIL: Agent exited with code $EXIT_CODE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "pg_basebackup verified"; then
|
||||
echo "FAIL: Expected output to contain 'pg_basebackup verified'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "PostgreSQL connection verified"; then
|
||||
echo "FAIL: Expected output to contain 'PostgreSQL connection verified'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "pg_basebackup found in PATH and DB connection verified"
|
||||
90
agent/e2e/scripts/test-upgrade-background.sh
Normal file
90
agent/e2e/scripts/test-upgrade-background.sh
Normal file
@@ -0,0 +1,90 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
|
||||
# Cleanup from previous runs
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
|
||||
# Set mock server to v1.0.0 (same as agent — no sync upgrade on start)
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"version":"v1.0.0"}'
|
||||
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-binary-path \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"binaryPath":"/artifacts/agent-v1"}'
|
||||
|
||||
# Copy v1 binary to writable location
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Verify initial version
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v1.0.0" ]; then
|
||||
echo "FAIL: Expected initial version v1.0.0, got $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
echo "Initial version: $VERSION"
|
||||
|
||||
# Start agent as daemon (versions match → no sync upgrade)
|
||||
mkdir -p /tmp/wal
|
||||
"$AGENT" start \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--pg-wal-dir /tmp/wal \
|
||||
--pg-type host
|
||||
|
||||
echo "Agent started as daemon, waiting for stabilization..."
|
||||
sleep 2
|
||||
|
||||
# Change mock server to v2.0.0 and point to v2 binary
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"version":"v2.0.0"}'
|
||||
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-binary-path \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"binaryPath":"/artifacts/agent-v2"}'
|
||||
|
||||
echo "Mock server updated to v2.0.0, waiting for background upgrade..."
|
||||
|
||||
# Poll for upgrade (timeout 60s, poll every 3s)
|
||||
DEADLINE=$((SECONDS + 60))
|
||||
while [ $SECONDS -lt $DEADLINE ]; do
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" = "v2.0.0" ]; then
|
||||
echo "Binary upgraded to $VERSION"
|
||||
break
|
||||
fi
|
||||
sleep 3
|
||||
done
|
||||
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v2.0.0" ]; then
|
||||
echo "FAIL: Expected v2.0.0 after background upgrade, got $VERSION"
|
||||
cat databasus.log 2>/dev/null || true
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Verify agent is still running after restart
|
||||
sleep 2
|
||||
"$AGENT" status || true
|
||||
|
||||
# Cleanup
|
||||
"$AGENT" stop || true
|
||||
|
||||
echo "Background upgrade test passed"
|
||||
61
agent/e2e/scripts/test-upgrade-skip.sh
Normal file
61
agent/e2e/scripts/test-upgrade-skip.sh
Normal file
@@ -0,0 +1,61 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
|
||||
# Cleanup from previous runs
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
|
||||
# Set mock server to return v1.0.0 (same as agent)
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"version":"v1.0.0"}'
|
||||
|
||||
# Copy v1 binary to writable location
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Verify initial version
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v1.0.0" ]; then
|
||||
echo "FAIL: Expected initial version v1.0.0, got $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run start — agent should see version matches and skip upgrade
|
||||
echo "Running agent start (expecting upgrade skip)..."
|
||||
OUTPUT=$("$AGENT" start \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--pg-wal-dir /tmp/wal \
|
||||
--pg-type host 2>&1) || true
|
||||
|
||||
echo "$OUTPUT"
|
||||
|
||||
# Verify output contains "up to date"
|
||||
if ! echo "$OUTPUT" | grep -qi "up to date"; then
|
||||
echo "FAIL: Expected output to contain 'up to date'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Verify binary is still v1
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v1.0.0" ]; then
|
||||
echo "FAIL: Expected version v1.0.0 (unchanged), got $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Upgrade correctly skipped, version still $VERSION"
|
||||
66
agent/e2e/scripts/test-upgrade-success.sh
Normal file
66
agent/e2e/scripts/test-upgrade-success.sh
Normal file
@@ -0,0 +1,66 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
|
||||
# Cleanup from previous runs
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
|
||||
# Ensure mock server returns v2.0.0 and serves v2 binary
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"version":"v2.0.0"}'
|
||||
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-binary-path \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"binaryPath":"/artifacts/agent-v2"}'
|
||||
|
||||
# Copy v1 binary to writable location
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Verify initial version
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v1.0.0" ]; then
|
||||
echo "FAIL: Expected initial version v1.0.0, got $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
echo "Initial version: $VERSION"
|
||||
|
||||
# Run start — agent will:
|
||||
# 1. Fetch version from mock (v2.0.0 != v1.0.0)
|
||||
# 2. Download v2 binary from mock
|
||||
# 3. Replace itself on disk
|
||||
# 4. Re-exec with same args
|
||||
# 5. Re-exec'd v2 fetches version (v2.0.0 == v2.0.0) → skips update
|
||||
# 6. Proceeds to start → verifies pg_basebackup + DB → exits 0 (stub)
|
||||
echo "Running agent start (expecting upgrade v1 -> v2)..."
|
||||
OUTPUT=$("$AGENT" start \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--pg-wal-dir /tmp/wal \
|
||||
--pg-type host 2>&1) || true
|
||||
|
||||
echo "$OUTPUT"
|
||||
|
||||
# Verify binary on disk is now v2
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v2.0.0" ]; then
|
||||
echo "FAIL: Expected upgraded version v2.0.0, got $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Binary upgraded successfully to $VERSION"
|
||||
22
agent/go.mod
Normal file
22
agent/go.mod
Normal file
@@ -0,0 +1,22 @@
|
||||
module databasus-agent
|
||||
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
github.com/go-resty/resty/v2 v2.17.2
|
||||
github.com/jackc/pgx/v5 v5.8.0
|
||||
github.com/klauspost/compress v1.18.4
|
||||
github.com/stretchr/testify v1.11.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
golang.org/x/net v0.43.0 // indirect
|
||||
golang.org/x/text v0.29.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
43
agent/go.sum
Normal file
43
agent/go.sum
Normal file
@@ -0,0 +1,43 @@
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk=
|
||||
github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
|
||||
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
|
||||
github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
|
||||
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
|
||||
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
272
agent/internal/config/config.go
Normal file
272
agent/internal/config/config.go
Normal file
@@ -0,0 +1,272 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
var log = logger.GetLogger()
|
||||
|
||||
const configFileName = "databasus.json"
|
||||
|
||||
type Config struct {
|
||||
DatabasusHost string `json:"databasusHost"`
|
||||
DbID string `json:"dbId"`
|
||||
Token string `json:"token"`
|
||||
PgHost string `json:"pgHost"`
|
||||
PgPort int `json:"pgPort"`
|
||||
PgUser string `json:"pgUser"`
|
||||
PgPassword string `json:"pgPassword"`
|
||||
PgType string `json:"pgType"`
|
||||
PgHostBinDir string `json:"pgHostBinDir"`
|
||||
PgDockerContainerName string `json:"pgDockerContainerName"`
|
||||
PgWalDir string `json:"pgWalDir"`
|
||||
IsDeleteWalAfterUpload *bool `json:"deleteWalAfterUpload"`
|
||||
|
||||
flags parsedFlags
|
||||
}
|
||||
|
||||
// LoadFromJSONAndArgs reads databasus.json into the struct
|
||||
// and overrides JSON values with any explicitly provided CLI flags.
|
||||
func (c *Config) LoadFromJSONAndArgs(fs *flag.FlagSet, args []string) {
|
||||
c.loadFromJSON()
|
||||
c.applyDefaults()
|
||||
c.initSources()
|
||||
|
||||
c.flags.databasusHost = fs.String(
|
||||
"databasus-host",
|
||||
"",
|
||||
"Databasus server URL (e.g. http://your-server:4005)",
|
||||
)
|
||||
c.flags.dbID = fs.String("db-id", "", "Database ID")
|
||||
c.flags.token = fs.String("token", "", "Agent token")
|
||||
c.flags.pgHost = fs.String("pg-host", "", "PostgreSQL host")
|
||||
c.flags.pgPort = fs.Int("pg-port", 0, "PostgreSQL port")
|
||||
c.flags.pgUser = fs.String("pg-user", "", "PostgreSQL user")
|
||||
c.flags.pgPassword = fs.String("pg-password", "", "PostgreSQL password")
|
||||
c.flags.pgType = fs.String("pg-type", "", "PostgreSQL type: host or docker")
|
||||
c.flags.pgHostBinDir = fs.String("pg-host-bin-dir", "", "Path to PG bin directory (host mode)")
|
||||
c.flags.pgDockerContainerName = fs.String("pg-docker-container-name", "", "Docker container name (docker mode)")
|
||||
c.flags.pgWalDir = fs.String("pg-wal-dir", "", "Path to WAL queue directory")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
c.applyFlags()
|
||||
log.Info("========= Loading config ============")
|
||||
c.logConfigSources()
|
||||
log.Info("========= Config has been loaded ====")
|
||||
}
|
||||
|
||||
// SaveToJSON writes the current struct to databasus.json.
|
||||
func (c *Config) SaveToJSON() error {
|
||||
data, err := json.MarshalIndent(c, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(configFileName, data, 0o644)
|
||||
}
|
||||
|
||||
func (c *Config) LoadFromJSON() {
|
||||
c.loadFromJSON()
|
||||
c.applyDefaults()
|
||||
}
|
||||
|
||||
func (c *Config) loadFromJSON() {
|
||||
data, err := os.ReadFile(configFileName)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
log.Info("No databasus.json found, will create on save")
|
||||
return
|
||||
}
|
||||
|
||||
log.Warn("Failed to read databasus.json", "error", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, c); err != nil {
|
||||
log.Warn("Failed to parse databasus.json", "error", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("Configuration loaded from " + configFileName)
|
||||
}
|
||||
|
||||
func (c *Config) applyDefaults() {
|
||||
if c.PgPort == 0 {
|
||||
c.PgPort = 5432
|
||||
}
|
||||
|
||||
if c.PgType == "" {
|
||||
c.PgType = "host"
|
||||
}
|
||||
|
||||
if c.IsDeleteWalAfterUpload == nil {
|
||||
v := true
|
||||
c.IsDeleteWalAfterUpload = &v
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) initSources() {
|
||||
c.flags.sources = map[string]string{
|
||||
"databasus-host": "not configured",
|
||||
"db-id": "not configured",
|
||||
"token": "not configured",
|
||||
"pg-host": "not configured",
|
||||
"pg-port": "not configured",
|
||||
"pg-user": "not configured",
|
||||
"pg-password": "not configured",
|
||||
"pg-type": "not configured",
|
||||
"pg-host-bin-dir": "not configured",
|
||||
"pg-docker-container-name": "not configured",
|
||||
"pg-wal-dir": "not configured",
|
||||
"delete-wal-after-upload": "not configured",
|
||||
}
|
||||
|
||||
if c.DatabasusHost != "" {
|
||||
c.flags.sources["databasus-host"] = configFileName
|
||||
}
|
||||
|
||||
if c.DbID != "" {
|
||||
c.flags.sources["db-id"] = configFileName
|
||||
}
|
||||
|
||||
if c.Token != "" {
|
||||
c.flags.sources["token"] = configFileName
|
||||
}
|
||||
|
||||
if c.PgHost != "" {
|
||||
c.flags.sources["pg-host"] = configFileName
|
||||
}
|
||||
|
||||
// PgPort always has a value after applyDefaults
|
||||
c.flags.sources["pg-port"] = configFileName
|
||||
|
||||
if c.PgUser != "" {
|
||||
c.flags.sources["pg-user"] = configFileName
|
||||
}
|
||||
|
||||
if c.PgPassword != "" {
|
||||
c.flags.sources["pg-password"] = configFileName
|
||||
}
|
||||
|
||||
// PgType always has a value after applyDefaults
|
||||
c.flags.sources["pg-type"] = configFileName
|
||||
|
||||
if c.PgHostBinDir != "" {
|
||||
c.flags.sources["pg-host-bin-dir"] = configFileName
|
||||
}
|
||||
|
||||
if c.PgDockerContainerName != "" {
|
||||
c.flags.sources["pg-docker-container-name"] = configFileName
|
||||
}
|
||||
|
||||
if c.PgWalDir != "" {
|
||||
c.flags.sources["pg-wal-dir"] = configFileName
|
||||
}
|
||||
|
||||
// IsDeleteWalAfterUpload always has a value after applyDefaults
|
||||
c.flags.sources["delete-wal-after-upload"] = configFileName
|
||||
}
|
||||
|
||||
func (c *Config) applyFlags() {
|
||||
if c.flags.databasusHost != nil && *c.flags.databasusHost != "" {
|
||||
c.DatabasusHost = *c.flags.databasusHost
|
||||
c.flags.sources["databasus-host"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.dbID != nil && *c.flags.dbID != "" {
|
||||
c.DbID = *c.flags.dbID
|
||||
c.flags.sources["db-id"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.token != nil && *c.flags.token != "" {
|
||||
c.Token = *c.flags.token
|
||||
c.flags.sources["token"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgHost != nil && *c.flags.pgHost != "" {
|
||||
c.PgHost = *c.flags.pgHost
|
||||
c.flags.sources["pg-host"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgPort != nil && *c.flags.pgPort != 0 {
|
||||
c.PgPort = *c.flags.pgPort
|
||||
c.flags.sources["pg-port"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgUser != nil && *c.flags.pgUser != "" {
|
||||
c.PgUser = *c.flags.pgUser
|
||||
c.flags.sources["pg-user"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgPassword != nil && *c.flags.pgPassword != "" {
|
||||
c.PgPassword = *c.flags.pgPassword
|
||||
c.flags.sources["pg-password"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgType != nil && *c.flags.pgType != "" {
|
||||
c.PgType = *c.flags.pgType
|
||||
c.flags.sources["pg-type"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgHostBinDir != nil && *c.flags.pgHostBinDir != "" {
|
||||
c.PgHostBinDir = *c.flags.pgHostBinDir
|
||||
c.flags.sources["pg-host-bin-dir"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgDockerContainerName != nil && *c.flags.pgDockerContainerName != "" {
|
||||
c.PgDockerContainerName = *c.flags.pgDockerContainerName
|
||||
c.flags.sources["pg-docker-container-name"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgWalDir != nil && *c.flags.pgWalDir != "" {
|
||||
c.PgWalDir = *c.flags.pgWalDir
|
||||
c.flags.sources["pg-wal-dir"] = "command line args"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) logConfigSources() {
|
||||
log.Info("databasus-host", "value", c.DatabasusHost, "source", c.flags.sources["databasus-host"])
|
||||
log.Info("db-id", "value", c.DbID, "source", c.flags.sources["db-id"])
|
||||
log.Info("token", "value", maskSensitive(c.Token), "source", c.flags.sources["token"])
|
||||
log.Info("pg-host", "value", c.PgHost, "source", c.flags.sources["pg-host"])
|
||||
log.Info("pg-port", "value", c.PgPort, "source", c.flags.sources["pg-port"])
|
||||
log.Info("pg-user", "value", c.PgUser, "source", c.flags.sources["pg-user"])
|
||||
log.Info("pg-password", "value", maskSensitive(c.PgPassword), "source", c.flags.sources["pg-password"])
|
||||
log.Info("pg-type", "value", c.PgType, "source", c.flags.sources["pg-type"])
|
||||
log.Info("pg-host-bin-dir", "value", c.PgHostBinDir, "source", c.flags.sources["pg-host-bin-dir"])
|
||||
log.Info(
|
||||
"pg-docker-container-name",
|
||||
"value",
|
||||
c.PgDockerContainerName,
|
||||
"source",
|
||||
c.flags.sources["pg-docker-container-name"],
|
||||
)
|
||||
log.Info("pg-wal-dir", "value", c.PgWalDir, "source", c.flags.sources["pg-wal-dir"])
|
||||
log.Info(
|
||||
"delete-wal-after-upload",
|
||||
"value",
|
||||
fmt.Sprintf("%v", *c.IsDeleteWalAfterUpload),
|
||||
"source",
|
||||
c.flags.sources["delete-wal-after-upload"],
|
||||
)
|
||||
}
|
||||
|
||||
func maskSensitive(value string) string {
|
||||
if value == "" {
|
||||
return "(not set)"
|
||||
}
|
||||
|
||||
visibleLen := max(len(value)/4, 1)
|
||||
|
||||
return value[:visibleLen] + "***"
|
||||
}
|
||||
301
agent/internal/config/config_test.go
Normal file
301
agent/internal/config/config_test.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_LoadFromJSONAndArgs_ValuesLoadedFromJSON(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
writeConfigJSON(t, dir, Config{
|
||||
DatabasusHost: "http://json-host:4005",
|
||||
DbID: "json-db-id",
|
||||
Token: "json-token",
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{})
|
||||
|
||||
assert.Equal(t, "http://json-host:4005", cfg.DatabasusHost)
|
||||
assert.Equal(t, "json-db-id", cfg.DbID)
|
||||
assert.Equal(t, "json-token", cfg.Token)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_ValuesLoadedFromArgs_WhenNoJSON(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--databasus-host", "http://arg-host:4005",
|
||||
"--db-id", "arg-db-id",
|
||||
"--token", "arg-token",
|
||||
})
|
||||
|
||||
assert.Equal(t, "http://arg-host:4005", cfg.DatabasusHost)
|
||||
assert.Equal(t, "arg-db-id", cfg.DbID)
|
||||
assert.Equal(t, "arg-token", cfg.Token)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_ArgsOverrideJSON(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
writeConfigJSON(t, dir, Config{
|
||||
DatabasusHost: "http://json-host:4005",
|
||||
DbID: "json-db-id",
|
||||
Token: "json-token",
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--databasus-host", "http://arg-host:9999",
|
||||
"--db-id", "arg-db-id-override",
|
||||
"--token", "arg-token-override",
|
||||
})
|
||||
|
||||
assert.Equal(t, "http://arg-host:9999", cfg.DatabasusHost)
|
||||
assert.Equal(t, "arg-db-id-override", cfg.DbID)
|
||||
assert.Equal(t, "arg-token-override", cfg.Token)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_PartialArgsOverrideJSON(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
writeConfigJSON(t, dir, Config{
|
||||
DatabasusHost: "http://json-host:4005",
|
||||
DbID: "json-db-id",
|
||||
Token: "json-token",
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--databasus-host", "http://arg-host-only:4005",
|
||||
})
|
||||
|
||||
assert.Equal(t, "http://arg-host-only:4005", cfg.DatabasusHost)
|
||||
assert.Equal(t, "json-db-id", cfg.DbID)
|
||||
assert.Equal(t, "json-token", cfg.Token)
|
||||
}
|
||||
|
||||
func Test_SaveToJSON_ConfigSavedCorrectly(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
deleteWal := true
|
||||
cfg := &Config{
|
||||
DatabasusHost: "http://save-host:4005",
|
||||
DbID: "save-db-id",
|
||||
Token: "save-token",
|
||||
IsDeleteWalAfterUpload: &deleteWal,
|
||||
}
|
||||
|
||||
err := cfg.SaveToJSON()
|
||||
require.NoError(t, err)
|
||||
|
||||
saved := readConfigJSON(t)
|
||||
|
||||
assert.Equal(t, "http://save-host:4005", saved.DatabasusHost)
|
||||
assert.Equal(t, "save-db-id", saved.DbID)
|
||||
assert.Equal(t, "save-token", saved.Token)
|
||||
}
|
||||
|
||||
func Test_SaveToJSON_AfterArgsOverrideJSON_SavedFileContainsMergedValues(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
writeConfigJSON(t, dir, Config{
|
||||
DatabasusHost: "http://json-host:4005",
|
||||
DbID: "json-db-id",
|
||||
Token: "json-token",
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--databasus-host", "http://override-host:9999",
|
||||
})
|
||||
|
||||
err := cfg.SaveToJSON()
|
||||
require.NoError(t, err)
|
||||
|
||||
saved := readConfigJSON(t)
|
||||
|
||||
assert.Equal(t, "http://override-host:9999", saved.DatabasusHost)
|
||||
assert.Equal(t, "json-db-id", saved.DbID)
|
||||
assert.Equal(t, "json-token", saved.Token)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_PgFieldsLoadedFromJSON(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
deleteWal := false
|
||||
writeConfigJSON(t, dir, Config{
|
||||
DatabasusHost: "http://json-host:4005",
|
||||
DbID: "json-db-id",
|
||||
Token: "json-token",
|
||||
PgHost: "pg-json-host",
|
||||
PgPort: 5433,
|
||||
PgUser: "pg-json-user",
|
||||
PgPassword: "pg-json-pass",
|
||||
PgType: "docker",
|
||||
PgHostBinDir: "/usr/bin",
|
||||
PgDockerContainerName: "pg-container",
|
||||
PgWalDir: "/opt/wal",
|
||||
IsDeleteWalAfterUpload: &deleteWal,
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{})
|
||||
|
||||
assert.Equal(t, "pg-json-host", cfg.PgHost)
|
||||
assert.Equal(t, 5433, cfg.PgPort)
|
||||
assert.Equal(t, "pg-json-user", cfg.PgUser)
|
||||
assert.Equal(t, "pg-json-pass", cfg.PgPassword)
|
||||
assert.Equal(t, "docker", cfg.PgType)
|
||||
assert.Equal(t, "/usr/bin", cfg.PgHostBinDir)
|
||||
assert.Equal(t, "pg-container", cfg.PgDockerContainerName)
|
||||
assert.Equal(t, "/opt/wal", cfg.PgWalDir)
|
||||
assert.Equal(t, false, *cfg.IsDeleteWalAfterUpload)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_PgFieldsLoadedFromArgs(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--pg-host", "arg-pg-host",
|
||||
"--pg-port", "5433",
|
||||
"--pg-user", "arg-pg-user",
|
||||
"--pg-password", "arg-pg-pass",
|
||||
"--pg-type", "docker",
|
||||
"--pg-host-bin-dir", "/custom/bin",
|
||||
"--pg-docker-container-name", "my-pg",
|
||||
"--pg-wal-dir", "/var/wal",
|
||||
})
|
||||
|
||||
assert.Equal(t, "arg-pg-host", cfg.PgHost)
|
||||
assert.Equal(t, 5433, cfg.PgPort)
|
||||
assert.Equal(t, "arg-pg-user", cfg.PgUser)
|
||||
assert.Equal(t, "arg-pg-pass", cfg.PgPassword)
|
||||
assert.Equal(t, "docker", cfg.PgType)
|
||||
assert.Equal(t, "/custom/bin", cfg.PgHostBinDir)
|
||||
assert.Equal(t, "my-pg", cfg.PgDockerContainerName)
|
||||
assert.Equal(t, "/var/wal", cfg.PgWalDir)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_PgArgsOverrideJSON(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
writeConfigJSON(t, dir, Config{
|
||||
PgHost: "json-host",
|
||||
PgPort: 5432,
|
||||
PgUser: "json-user",
|
||||
PgType: "host",
|
||||
PgWalDir: "/json/wal",
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--pg-host", "arg-host",
|
||||
"--pg-port", "5433",
|
||||
"--pg-user", "arg-user",
|
||||
"--pg-type", "docker",
|
||||
"--pg-docker-container-name", "my-container",
|
||||
"--pg-wal-dir", "/arg/wal",
|
||||
})
|
||||
|
||||
assert.Equal(t, "arg-host", cfg.PgHost)
|
||||
assert.Equal(t, 5433, cfg.PgPort)
|
||||
assert.Equal(t, "arg-user", cfg.PgUser)
|
||||
assert.Equal(t, "docker", cfg.PgType)
|
||||
assert.Equal(t, "my-container", cfg.PgDockerContainerName)
|
||||
assert.Equal(t, "/arg/wal", cfg.PgWalDir)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_DefaultsApplied_WhenNoJSONAndNoArgs(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{})
|
||||
|
||||
assert.Equal(t, 5432, cfg.PgPort)
|
||||
assert.Equal(t, "host", cfg.PgType)
|
||||
require.NotNil(t, cfg.IsDeleteWalAfterUpload)
|
||||
assert.Equal(t, true, *cfg.IsDeleteWalAfterUpload)
|
||||
}
|
||||
|
||||
func Test_SaveToJSON_PgFieldsSavedCorrectly(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
deleteWal := false
|
||||
cfg := &Config{
|
||||
DatabasusHost: "http://host:4005",
|
||||
DbID: "db-id",
|
||||
Token: "token",
|
||||
PgHost: "pg-host",
|
||||
PgPort: 5433,
|
||||
PgUser: "pg-user",
|
||||
PgPassword: "pg-pass",
|
||||
PgType: "docker",
|
||||
PgHostBinDir: "/usr/bin",
|
||||
PgDockerContainerName: "pg-container",
|
||||
PgWalDir: "/opt/wal",
|
||||
IsDeleteWalAfterUpload: &deleteWal,
|
||||
}
|
||||
|
||||
err := cfg.SaveToJSON()
|
||||
require.NoError(t, err)
|
||||
|
||||
saved := readConfigJSON(t)
|
||||
|
||||
assert.Equal(t, "pg-host", saved.PgHost)
|
||||
assert.Equal(t, 5433, saved.PgPort)
|
||||
assert.Equal(t, "pg-user", saved.PgUser)
|
||||
assert.Equal(t, "pg-pass", saved.PgPassword)
|
||||
assert.Equal(t, "docker", saved.PgType)
|
||||
assert.Equal(t, "/usr/bin", saved.PgHostBinDir)
|
||||
assert.Equal(t, "pg-container", saved.PgDockerContainerName)
|
||||
assert.Equal(t, "/opt/wal", saved.PgWalDir)
|
||||
require.NotNil(t, saved.IsDeleteWalAfterUpload)
|
||||
assert.Equal(t, false, *saved.IsDeleteWalAfterUpload)
|
||||
}
|
||||
|
||||
func setupTempDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
origDir, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
|
||||
dir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(dir))
|
||||
|
||||
t.Cleanup(func() { os.Chdir(origDir) })
|
||||
|
||||
return dir
|
||||
}
|
||||
|
||||
func writeConfigJSON(t *testing.T, dir string, cfg Config) {
|
||||
t.Helper()
|
||||
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, os.WriteFile(dir+"/"+configFileName, data, 0o644))
|
||||
}
|
||||
|
||||
func readConfigJSON(t *testing.T) Config {
|
||||
t.Helper()
|
||||
|
||||
data, err := os.ReadFile(configFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
var cfg Config
|
||||
require.NoError(t, json.Unmarshal(data, &cfg))
|
||||
|
||||
return cfg
|
||||
}
|
||||
17
agent/internal/config/dto.go
Normal file
17
agent/internal/config/dto.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package config
|
||||
|
||||
type parsedFlags struct {
|
||||
databasusHost *string
|
||||
dbID *string
|
||||
token *string
|
||||
pgHost *string
|
||||
pgPort *int
|
||||
pgUser *string
|
||||
pgPassword *string
|
||||
pgType *string
|
||||
pgHostBinDir *string
|
||||
pgDockerContainerName *string
|
||||
pgWalDir *string
|
||||
|
||||
sources map[string]string
|
||||
}
|
||||
288
agent/internal/features/api/api.go
Normal file
288
agent/internal/features/api/api.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
chainValidPath = "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup"
|
||||
nextBackupTimePath = "/api/v1/backups/postgres/wal/next-full-backup-time"
|
||||
walUploadPath = "/api/v1/backups/postgres/wal/upload/wal"
|
||||
fullStartPath = "/api/v1/backups/postgres/wal/upload/full-start"
|
||||
fullCompletePath = "/api/v1/backups/postgres/wal/upload/full-complete"
|
||||
reportErrorPath = "/api/v1/backups/postgres/wal/error"
|
||||
versionPath = "/api/v1/system/version"
|
||||
agentBinaryPath = "/api/v1/system/agent"
|
||||
|
||||
apiCallTimeout = 30 * time.Second
|
||||
maxRetryAttempts = 3
|
||||
retryBaseDelay = 1 * time.Second
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
json *resty.Client
|
||||
stream *resty.Client
|
||||
host string
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
func NewClient(host, token string, log *slog.Logger) *Client {
|
||||
setAuth := func(_ *resty.Client, req *resty.Request) error {
|
||||
if token != "" {
|
||||
req.SetHeader("Authorization", token)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
jsonClient := resty.New().
|
||||
SetTimeout(apiCallTimeout).
|
||||
SetRetryCount(maxRetryAttempts - 1).
|
||||
SetRetryWaitTime(retryBaseDelay).
|
||||
SetRetryMaxWaitTime(4 * retryBaseDelay).
|
||||
AddRetryCondition(func(resp *resty.Response, err error) bool {
|
||||
return err != nil || resp.StatusCode() >= 500
|
||||
}).
|
||||
OnBeforeRequest(setAuth)
|
||||
|
||||
streamClient := resty.New().
|
||||
OnBeforeRequest(setAuth)
|
||||
|
||||
return &Client{
|
||||
json: jsonClient,
|
||||
stream: streamClient,
|
||||
host: host,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) CheckWalChainValidity(ctx context.Context) (*WalChainValidityResponse, error) {
|
||||
var resp WalChainValidityResponse
|
||||
|
||||
httpResp, err := c.json.R().
|
||||
SetContext(ctx).
|
||||
SetResult(&resp).
|
||||
Get(c.buildURL(chainValidPath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.checkResponse(httpResp, "check WAL chain validity"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetNextFullBackupTime(ctx context.Context) (*NextFullBackupTimeResponse, error) {
|
||||
var resp NextFullBackupTimeResponse
|
||||
|
||||
httpResp, err := c.json.R().
|
||||
SetContext(ctx).
|
||||
SetResult(&resp).
|
||||
Get(c.buildURL(nextBackupTimePath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.checkResponse(httpResp, "get next full backup time"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (c *Client) ReportBackupError(ctx context.Context, errMsg string) error {
|
||||
httpResp, err := c.json.R().
|
||||
SetContext(ctx).
|
||||
SetBody(reportErrorRequest{Error: errMsg}).
|
||||
Post(c.buildURL(reportErrorPath))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.checkResponse(httpResp, "report backup error")
|
||||
}
|
||||
|
||||
func (c *Client) UploadBasebackup(
|
||||
ctx context.Context,
|
||||
body io.Reader,
|
||||
) (*UploadBasebackupResponse, error) {
|
||||
resp, err := c.stream.R().
|
||||
SetContext(ctx).
|
||||
SetBody(body).
|
||||
SetHeader("Content-Type", "application/octet-stream").
|
||||
SetDoNotParseResponse(true).
|
||||
Post(c.buildURL(fullStartPath))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upload request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.RawBody().Close() }()
|
||||
|
||||
if resp.StatusCode() != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.RawBody())
|
||||
|
||||
return nil, fmt.Errorf("upload failed with status %d: %s", resp.StatusCode(), string(respBody))
|
||||
}
|
||||
|
||||
var result UploadBasebackupResponse
|
||||
if err := json.NewDecoder(resp.RawBody()).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("decode upload response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *Client) FinalizeBasebackup(
|
||||
ctx context.Context,
|
||||
backupID string,
|
||||
startSegment string,
|
||||
stopSegment string,
|
||||
) error {
|
||||
resp, err := c.json.R().
|
||||
SetContext(ctx).
|
||||
SetBody(finalizeBasebackupRequest{
|
||||
BackupID: backupID,
|
||||
StartSegment: startSegment,
|
||||
StopSegment: stopSegment,
|
||||
}).
|
||||
Post(c.buildURL(fullCompletePath))
|
||||
if err != nil {
|
||||
return fmt.Errorf("finalize request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode() != http.StatusOK {
|
||||
return fmt.Errorf("finalize failed with status %d: %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) FinalizeBasebackupWithError(
|
||||
ctx context.Context,
|
||||
backupID string,
|
||||
errMsg string,
|
||||
) error {
|
||||
resp, err := c.json.R().
|
||||
SetContext(ctx).
|
||||
SetBody(finalizeBasebackupRequest{
|
||||
BackupID: backupID,
|
||||
Error: &errMsg,
|
||||
}).
|
||||
Post(c.buildURL(fullCompletePath))
|
||||
if err != nil {
|
||||
return fmt.Errorf("finalize-with-error request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode() != http.StatusOK {
|
||||
return fmt.Errorf("finalize-with-error failed with status %d: %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) UploadWalSegment(
|
||||
ctx context.Context,
|
||||
segmentName string,
|
||||
body io.Reader,
|
||||
) (*UploadWalSegmentResult, error) {
|
||||
resp, err := c.stream.R().
|
||||
SetContext(ctx).
|
||||
SetBody(body).
|
||||
SetHeader("Content-Type", "application/octet-stream").
|
||||
SetHeader("X-Wal-Segment-Name", segmentName).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(c.buildURL(walUploadPath))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upload request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.RawBody().Close() }()
|
||||
|
||||
switch resp.StatusCode() {
|
||||
case http.StatusNoContent:
|
||||
return &UploadWalSegmentResult{IsGapDetected: false}, nil
|
||||
|
||||
case http.StatusConflict:
|
||||
var errResp uploadErrorResponse
|
||||
|
||||
if err := json.NewDecoder(resp.RawBody()).Decode(&errResp); err != nil {
|
||||
return &UploadWalSegmentResult{IsGapDetected: true}, nil
|
||||
}
|
||||
|
||||
return &UploadWalSegmentResult{
|
||||
IsGapDetected: true,
|
||||
ExpectedSegmentName: errResp.ExpectedSegmentName,
|
||||
ReceivedSegmentName: errResp.ReceivedSegmentName,
|
||||
}, nil
|
||||
|
||||
default:
|
||||
respBody, _ := io.ReadAll(resp.RawBody())
|
||||
|
||||
return nil, fmt.Errorf("upload failed with status %d: %s", resp.StatusCode(), string(respBody))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) FetchServerVersion(ctx context.Context) (string, error) {
|
||||
var ver versionResponse
|
||||
|
||||
httpResp, err := c.json.R().
|
||||
SetContext(ctx).
|
||||
SetResult(&ver).
|
||||
Get(c.buildURL(versionPath))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := c.checkResponse(httpResp, "fetch server version"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return ver.Version, nil
|
||||
}
|
||||
|
||||
func (c *Client) DownloadAgentBinary(ctx context.Context, arch, destPath string) error {
|
||||
resp, err := c.stream.R().
|
||||
SetContext(ctx).
|
||||
SetQueryParam("arch", arch).
|
||||
SetDoNotParseResponse(true).
|
||||
Get(c.buildURL(agentBinaryPath))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.RawBody().Close() }()
|
||||
|
||||
if resp.StatusCode() != http.StatusOK {
|
||||
return fmt.Errorf("server returned %d for agent download", resp.StatusCode())
|
||||
}
|
||||
|
||||
f, err := os.Create(destPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
_, err = io.Copy(f, resp.RawBody())
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) buildURL(path string) string {
|
||||
return c.host + path
|
||||
}
|
||||
|
||||
func (c *Client) checkResponse(resp *resty.Response, method string) error {
|
||||
if resp.StatusCode() >= 400 {
|
||||
return fmt.Errorf("%s: server returned status %d: %s", method, resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
44
agent/internal/features/api/dto.go
Normal file
44
agent/internal/features/api/dto.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package api
|
||||
|
||||
import "time"
|
||||
|
||||
type WalChainValidityResponse struct {
|
||||
IsValid bool `json:"isValid"`
|
||||
Error string `json:"error,omitempty"`
|
||||
LastContiguousSegment string `json:"lastContiguousSegment,omitempty"`
|
||||
}
|
||||
|
||||
type NextFullBackupTimeResponse struct {
|
||||
NextFullBackupTime *time.Time `json:"nextFullBackupTime"`
|
||||
}
|
||||
|
||||
type UploadWalSegmentResult struct {
|
||||
IsGapDetected bool
|
||||
ExpectedSegmentName string
|
||||
ReceivedSegmentName string
|
||||
}
|
||||
|
||||
type reportErrorRequest struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type versionResponse struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
type UploadBasebackupResponse struct {
|
||||
BackupID string `json:"backupId"`
|
||||
}
|
||||
|
||||
type finalizeBasebackupRequest struct {
|
||||
BackupID string `json:"backupId"`
|
||||
StartSegment string `json:"startSegment"`
|
||||
StopSegment string `json:"stopSegment"`
|
||||
Error *string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type uploadErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
ExpectedSegmentName string `json:"expectedSegmentName"`
|
||||
ReceivedSegmentName string `json:"receivedSegmentName"`
|
||||
}
|
||||
292
agent/internal/features/full_backup/backuper.go
Normal file
292
agent/internal/features/full_backup/backuper.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package full_backup
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
)
|
||||
|
||||
const (
|
||||
checkInterval = 30 * time.Second
|
||||
retryDelay = 1 * time.Minute
|
||||
uploadTimeout = 30 * time.Minute
|
||||
)
|
||||
|
||||
var retryDelayOverride *time.Duration
|
||||
|
||||
type CmdBuilder func(ctx context.Context) *exec.Cmd
|
||||
|
||||
// FullBackuper runs pg_basebackup when the WAL chain is broken or a scheduled backup is due.
|
||||
//
|
||||
// Every 30 seconds it checks two conditions via the Databasus API:
|
||||
// 1. WAL chain validity — if broken or no full backup exists, triggers an immediate basebackup.
|
||||
// 2. Scheduled backup time — if the next full backup time has passed, triggers a basebackup.
|
||||
//
|
||||
// Only one basebackup runs at a time (guarded by atomic bool).
|
||||
// On failure the error is reported to the server and the backup retries after 1 minute, indefinitely.
|
||||
// WAL segment uploads (handled by wal.Streamer) continue independently and are not paused.
|
||||
//
|
||||
// pg_basebackup runs as "pg_basebackup -Ft -D - -X none --verbose". Stdout (tar) is zstd-compressed
|
||||
// and uploaded to the server. Stderr is parsed for WAL start/stop segment names (LSN → segment arithmetic).
|
||||
type FullBackuper struct {
|
||||
cfg *config.Config
|
||||
apiClient *api.Client
|
||||
log *slog.Logger
|
||||
isRunning atomic.Bool
|
||||
cmdBuilder CmdBuilder
|
||||
}
|
||||
|
||||
func NewFullBackuper(cfg *config.Config, apiClient *api.Client, log *slog.Logger) *FullBackuper {
|
||||
backuper := &FullBackuper{
|
||||
cfg: cfg,
|
||||
apiClient: apiClient,
|
||||
log: log,
|
||||
}
|
||||
|
||||
backuper.cmdBuilder = backuper.defaultCmdBuilder
|
||||
|
||||
return backuper
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) Run(ctx context.Context) {
|
||||
backuper.log.Info("Full backuper started")
|
||||
|
||||
backuper.checkAndRunIfNeeded(ctx)
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
backuper.log.Info("Full backuper stopping")
|
||||
return
|
||||
case <-ticker.C:
|
||||
backuper.checkAndRunIfNeeded(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) checkAndRunIfNeeded(ctx context.Context) {
|
||||
if backuper.isRunning.Load() {
|
||||
backuper.log.Debug("Skipping check: basebackup already in progress")
|
||||
return
|
||||
}
|
||||
|
||||
chainResp, err := backuper.apiClient.CheckWalChainValidity(ctx)
|
||||
if err != nil {
|
||||
backuper.log.Error("Failed to check WAL chain validity", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !chainResp.IsValid {
|
||||
backuper.log.Info("WAL chain is invalid, triggering basebackup",
|
||||
"error", chainResp.Error,
|
||||
"lastContiguousSegment", chainResp.LastContiguousSegment,
|
||||
)
|
||||
|
||||
backuper.runBasebackupWithRetry(ctx)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
nextTimeResp, err := backuper.apiClient.GetNextFullBackupTime(ctx)
|
||||
if err != nil {
|
||||
backuper.log.Error("Failed to check next full backup time", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if nextTimeResp.NextFullBackupTime == nil || !nextTimeResp.NextFullBackupTime.After(time.Now().UTC()) {
|
||||
backuper.log.Info("Scheduled full backup is due, triggering basebackup")
|
||||
backuper.runBasebackupWithRetry(ctx)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backuper.log.Debug("No basebackup needed",
|
||||
"nextFullBackupTime", nextTimeResp.NextFullBackupTime,
|
||||
)
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) runBasebackupWithRetry(ctx context.Context) {
|
||||
if !backuper.isRunning.CompareAndSwap(false, true) {
|
||||
backuper.log.Debug("Skipping basebackup: already running")
|
||||
return
|
||||
}
|
||||
defer backuper.isRunning.Store(false)
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
backuper.log.Info("Starting pg_basebackup")
|
||||
|
||||
err := backuper.executeAndUploadBasebackup(ctx)
|
||||
if err == nil {
|
||||
backuper.log.Info("Basebackup completed successfully")
|
||||
return
|
||||
}
|
||||
|
||||
backuper.log.Error("Basebackup failed", "error", err)
|
||||
backuper.reportError(ctx, err.Error())
|
||||
|
||||
delay := retryDelay
|
||||
if retryDelayOverride != nil {
|
||||
delay = *retryDelayOverride
|
||||
}
|
||||
|
||||
backuper.log.Info("Retrying basebackup after delay", "delay", delay)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(delay):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) executeAndUploadBasebackup(ctx context.Context) error {
|
||||
cmd := backuper.cmdBuilder(ctx)
|
||||
|
||||
var stderrBuf bytes.Buffer
|
||||
cmd.Stderr = &stderrBuf
|
||||
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("start pg_basebackup: %w", err)
|
||||
}
|
||||
|
||||
// Phase 1: Stream compressed data via io.Pipe directly to the API.
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
go backuper.compressAndStream(pipeWriter, stdoutPipe)
|
||||
|
||||
uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout)
|
||||
defer cancel()
|
||||
|
||||
uploadResp, uploadErr := backuper.apiClient.UploadBasebackup(uploadCtx, pipeReader)
|
||||
|
||||
cmdErr := cmd.Wait()
|
||||
|
||||
if uploadErr != nil {
|
||||
return fmt.Errorf("upload basebackup: %w", uploadErr)
|
||||
}
|
||||
|
||||
if cmdErr != nil {
|
||||
errMsg := fmt.Sprintf("pg_basebackup exited with error: %v (stderr: %s)", cmdErr, stderrBuf.String())
|
||||
_ = backuper.apiClient.FinalizeBasebackupWithError(ctx, uploadResp.BackupID, errMsg)
|
||||
|
||||
return fmt.Errorf("pg_basebackup: %w", cmdErr)
|
||||
}
|
||||
|
||||
// Phase 2: Parse stderr for WAL segments and finalize the backup.
|
||||
stderrStr := stderrBuf.String()
|
||||
backuper.log.Debug("pg_basebackup stderr", "stderr", stderrStr)
|
||||
|
||||
startSegment, stopSegment, err := ParseBasebackupStderr(stderrStr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("parse pg_basebackup stderr: %v", err)
|
||||
_ = backuper.apiClient.FinalizeBasebackupWithError(ctx, uploadResp.BackupID, errMsg)
|
||||
|
||||
return fmt.Errorf("parse pg_basebackup stderr: %w", err)
|
||||
}
|
||||
|
||||
backuper.log.Info("Basebackup WAL segments parsed",
|
||||
"startSegment", startSegment,
|
||||
"stopSegment", stopSegment,
|
||||
"backupId", uploadResp.BackupID,
|
||||
)
|
||||
|
||||
if err := backuper.apiClient.FinalizeBasebackup(ctx, uploadResp.BackupID, startSegment, stopSegment); err != nil {
|
||||
return fmt.Errorf("finalize basebackup: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) compressAndStream(pipeWriter *io.PipeWriter, reader io.Reader) {
|
||||
encoder, err := zstd.NewWriter(pipeWriter,
|
||||
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(5)),
|
||||
zstd.WithEncoderCRC(true),
|
||||
)
|
||||
if err != nil {
|
||||
_ = pipeWriter.CloseWithError(fmt.Errorf("create zstd encoder: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := io.Copy(encoder, reader); err != nil {
|
||||
_ = encoder.Close()
|
||||
_ = pipeWriter.CloseWithError(fmt.Errorf("compress: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := encoder.Close(); err != nil {
|
||||
_ = pipeWriter.CloseWithError(fmt.Errorf("close encoder: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
_ = pipeWriter.Close()
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) reportError(ctx context.Context, errMsg string) {
|
||||
if err := backuper.apiClient.ReportBackupError(ctx, errMsg); err != nil {
|
||||
backuper.log.Error("Failed to report error to server", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) defaultCmdBuilder(ctx context.Context) *exec.Cmd {
|
||||
switch backuper.cfg.PgType {
|
||||
case "docker":
|
||||
return backuper.buildDockerCmd(ctx)
|
||||
default:
|
||||
return backuper.buildHostCmd(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) buildHostCmd(ctx context.Context) *exec.Cmd {
|
||||
binary := "pg_basebackup"
|
||||
if backuper.cfg.PgHostBinDir != "" {
|
||||
binary = filepath.Join(backuper.cfg.PgHostBinDir, "pg_basebackup")
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, binary,
|
||||
"-Ft", "-D", "-", "-X", "none", "--verbose",
|
||||
"-h", backuper.cfg.PgHost,
|
||||
"-p", fmt.Sprintf("%d", backuper.cfg.PgPort),
|
||||
"-U", backuper.cfg.PgUser,
|
||||
)
|
||||
|
||||
cmd.Env = append(os.Environ(), "PGPASSWORD="+backuper.cfg.PgPassword)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) buildDockerCmd(ctx context.Context) *exec.Cmd {
|
||||
cmd := exec.CommandContext(ctx, "docker", "exec",
|
||||
"-e", "PGPASSWORD="+backuper.cfg.PgPassword,
|
||||
"-i", backuper.cfg.PgDockerContainerName,
|
||||
"pg_basebackup",
|
||||
"-Ft", "-D", "-", "-X", "none", "--verbose",
|
||||
"-h", backuper.cfg.PgHost,
|
||||
"-p", fmt.Sprintf("%d", backuper.cfg.PgPort),
|
||||
"-U", backuper.cfg.PgUser,
|
||||
)
|
||||
|
||||
return cmd
|
||||
}
|
||||
671
agent/internal/features/full_backup/backuper_test.go
Normal file
671
agent/internal/features/full_backup/backuper_test.go
Normal file
@@ -0,0 +1,671 @@
|
||||
package full_backup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
testChainValidPath = "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup"
|
||||
testNextBackupTimePath = "/api/v1/backups/postgres/wal/next-full-backup-time"
|
||||
testFullStartPath = "/api/v1/backups/postgres/wal/upload/full-start"
|
||||
testFullCompletePath = "/api/v1/backups/postgres/wal/upload/full-complete"
|
||||
testReportErrorPath = "/api/v1/backups/postgres/wal/error"
|
||||
|
||||
testBackupID = "test-backup-id-1234"
|
||||
)
|
||||
|
||||
func Test_RunFullBackup_WhenChainBroken_BasebackupTriggered(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var uploadReceived bool
|
||||
var uploadHeaders http.Header
|
||||
var finalizeReceived bool
|
||||
var finalizeBody map[string]any
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "wal_chain_broken",
|
||||
LastContiguousSegment: "000000010000000100000011",
|
||||
})
|
||||
case testFullStartPath:
|
||||
mu.Lock()
|
||||
uploadReceived = true
|
||||
uploadHeaders = r.Header.Clone()
|
||||
mu.Unlock()
|
||||
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
mu.Lock()
|
||||
finalizeReceived = true
|
||||
_ = json.NewDecoder(r.Body).Decode(&finalizeBody)
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "test-backup-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return finalizeReceived
|
||||
}, 5*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.True(t, uploadReceived)
|
||||
assert.Equal(t, "application/octet-stream", uploadHeaders.Get("Content-Type"))
|
||||
assert.Equal(t, "test-token", uploadHeaders.Get("Authorization"))
|
||||
|
||||
assert.True(t, finalizeReceived)
|
||||
assert.Equal(t, testBackupID, finalizeBody["backupId"])
|
||||
assert.Equal(t, "000000010000000000000002", finalizeBody["startSegment"])
|
||||
assert.Equal(t, "000000010000000000000002", finalizeBody["stopSegment"])
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenScheduledBackupDue_BasebackupTriggered(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var finalizeReceived bool
|
||||
|
||||
pastTime := time.Now().UTC().Add(-1 * time.Hour)
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{IsValid: true})
|
||||
case testNextBackupTimePath:
|
||||
writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: &pastTime})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
mu.Lock()
|
||||
finalizeReceived = true
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "scheduled-backup-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return finalizeReceived
|
||||
}, 5*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.True(t, finalizeReceived)
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenNoFullBackupExists_ImmediateBasebackupTriggered(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var finalizeReceived bool
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
mu.Lock()
|
||||
finalizeReceived = true
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "first-backup-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return finalizeReceived
|
||||
}, 5*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.True(t, finalizeReceived)
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenUploadFails_RetriesAfterDelay(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var uploadAttempts int
|
||||
var errorReported bool
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
|
||||
mu.Lock()
|
||||
uploadAttempts++
|
||||
attempt := uploadAttempts
|
||||
mu.Unlock()
|
||||
|
||||
if attempt == 1 {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(`{"error":"storage unavailable"}`))
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
case testReportErrorPath:
|
||||
mu.Lock()
|
||||
errorReported = true
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "retry-backup-data", validStderr())
|
||||
|
||||
origRetryDelay := retryDelay
|
||||
setRetryDelay(100 * time.Millisecond)
|
||||
defer setRetryDelay(origRetryDelay)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return uploadAttempts >= 2
|
||||
}, 10*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.GreaterOrEqual(t, uploadAttempts, 2)
|
||||
assert.True(t, errorReported)
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenAlreadyRunning_SkipsExecution(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var uploadCount int
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
|
||||
mu.Lock()
|
||||
uploadCount++
|
||||
mu.Unlock()
|
||||
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
|
||||
|
||||
fb.isRunning.Store(true)
|
||||
|
||||
fb.checkAndRunIfNeeded(context.Background())
|
||||
|
||||
mu.Lock()
|
||||
count := uploadCount
|
||||
mu.Unlock()
|
||||
|
||||
assert.Equal(t, 0, count, "should not trigger backup when already running")
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenContextCancelled_StopsCleanly(t *testing.T) {
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
case testFullCompletePath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
case testReportErrorPath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
|
||||
|
||||
origRetryDelay := retryDelay
|
||||
setRetryDelay(5 * time.Second)
|
||||
defer setRetryDelay(origRetryDelay)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
fb.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Run should have stopped after context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenChainValidAndNotScheduled_NoBasebackupTriggered(t *testing.T) {
|
||||
var uploadReceived atomic.Bool
|
||||
|
||||
futureTime := time.Now().UTC().Add(24 * time.Hour)
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{IsValid: true})
|
||||
case testNextBackupTimePath:
|
||||
writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: &futureTime})
|
||||
case testFullStartPath:
|
||||
uploadReceived.Store(true)
|
||||
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
assert.False(t, uploadReceived.Load(), "should not trigger backup when chain valid and not scheduled")
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenStderrParsingFails_FinalizesWithErrorAndRetries(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var errorReported bool
|
||||
var finalizeWithErrorReceived bool
|
||||
var finalizeBody map[string]any
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
mu.Lock()
|
||||
finalizeWithErrorReceived = true
|
||||
_ = json.NewDecoder(r.Body).Decode(&finalizeBody)
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
case testReportErrorPath:
|
||||
mu.Lock()
|
||||
errorReported = true
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", "pg_basebackup: unexpected output with no LSN info")
|
||||
|
||||
origRetryDelay := retryDelay
|
||||
setRetryDelay(100 * time.Millisecond)
|
||||
defer setRetryDelay(origRetryDelay)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return errorReported
|
||||
}, 2*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.True(t, errorReported)
|
||||
assert.True(t, finalizeWithErrorReceived, "should finalize with error when stderr parsing fails")
|
||||
assert.Equal(t, testBackupID, finalizeBody["backupId"])
|
||||
assert.NotNil(t, finalizeBody["error"], "finalize should include error message")
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenNextBackupTimeNull_BasebackupTriggered(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var finalizeReceived bool
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{IsValid: true})
|
||||
case testNextBackupTimePath:
|
||||
writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: nil})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
mu.Lock()
|
||||
finalizeReceived = true
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "first-run-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return finalizeReceived
|
||||
}, 5*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.True(t, finalizeReceived)
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenChainValidityReturns401_NoBasebackupTriggered(t *testing.T) {
|
||||
var uploadReceived atomic.Bool
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"invalid token"}`))
|
||||
case testFullStartPath:
|
||||
uploadReceived.Store(true)
|
||||
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
assert.False(t, uploadReceived.Load(), "should not trigger backup when API returns 401")
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenUploadSucceeds_BodyIsZstdCompressed(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var receivedBody []byte
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
})
|
||||
case testFullStartPath:
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
|
||||
mu.Lock()
|
||||
receivedBody = body
|
||||
mu.Unlock()
|
||||
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
originalContent := "test-backup-content-for-compression-check"
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, originalContent, validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return len(receivedBody) > 0
|
||||
}, 5*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
body := receivedBody
|
||||
mu.Unlock()
|
||||
|
||||
decoder, err := zstd.NewReader(nil)
|
||||
require.NoError(t, err)
|
||||
defer decoder.Close()
|
||||
|
||||
decompressed, err := decoder.DecodeAll(body, nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalContent, string(decompressed))
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
server := httptest.NewServer(handler)
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
func newTestFullBackuper(serverURL string) *FullBackuper {
|
||||
cfg := &config.Config{
|
||||
DatabasusHost: serverURL,
|
||||
DbID: "test-db-id",
|
||||
Token: "test-token",
|
||||
PgHost: "localhost",
|
||||
PgPort: 5432,
|
||||
PgUser: "postgres",
|
||||
PgPassword: "password",
|
||||
PgType: "host",
|
||||
}
|
||||
|
||||
apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger())
|
||||
|
||||
return NewFullBackuper(cfg, apiClient, logger.GetLogger())
|
||||
}
|
||||
|
||||
func mockCmdBuilder(t *testing.T, stdoutContent, stderrContent string) CmdBuilder {
|
||||
t.Helper()
|
||||
|
||||
return func(ctx context.Context) *exec.Cmd {
|
||||
cmd := exec.CommandContext(ctx, os.Args[0],
|
||||
"-test.run=TestHelperProcess",
|
||||
"--",
|
||||
stdoutContent,
|
||||
stderrContent,
|
||||
)
|
||||
|
||||
cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS=1")
|
||||
|
||||
return cmd
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelperProcess(t *testing.T) {
|
||||
if os.Getenv("GO_TEST_HELPER_PROCESS") != "1" {
|
||||
return
|
||||
}
|
||||
|
||||
args := os.Args
|
||||
for i, arg := range args {
|
||||
if arg == "--" {
|
||||
args = args[i+1:]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(args) >= 1 {
|
||||
_, _ = fmt.Fprint(os.Stdout, args[0])
|
||||
}
|
||||
|
||||
if len(args) >= 2 {
|
||||
_, _ = fmt.Fprint(os.Stderr, args[1])
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func validStderr() string {
|
||||
return `pg_basebackup: initiating base backup, waiting for checkpoint to complete
|
||||
pg_basebackup: checkpoint completed
|
||||
pg_basebackup: write-ahead log start point: 0/2000028, on timeline 1
|
||||
pg_basebackup: checkpoint redo point at 0/2000028
|
||||
pg_basebackup: write-ahead log end point: 0/2000100
|
||||
pg_basebackup: base backup completed`
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, v any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if err := json.NewEncoder(w).Encode(v); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func waitForCondition(t *testing.T, condition func() bool, timeout time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(timeout)
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
if condition() {
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Fatalf("condition not met within %v", timeout)
|
||||
}
|
||||
|
||||
func setRetryDelay(d time.Duration) {
|
||||
retryDelayOverride = &d
|
||||
}
|
||||
|
||||
func init() {
|
||||
retryDelayOverride = nil
|
||||
}
|
||||
75
agent/internal/features/full_backup/stderr_parser.go
Normal file
75
agent/internal/features/full_backup/stderr_parser.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package full_backup
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const defaultWalSegmentSize uint32 = 16 * 1024 * 1024 // 16 MB
|
||||
|
||||
var (
|
||||
startLSNRegex = regexp.MustCompile(`checkpoint redo point at ([0-9A-Fa-f]+/[0-9A-Fa-f]+)`)
|
||||
stopLSNRegex = regexp.MustCompile(`write-ahead log end point: ([0-9A-Fa-f]+/[0-9A-Fa-f]+)`)
|
||||
)
|
||||
|
||||
func ParseBasebackupStderr(stderr string) (startSegment, stopSegment string, err error) {
|
||||
startMatch := startLSNRegex.FindStringSubmatch(stderr)
|
||||
if len(startMatch) < 2 {
|
||||
return "", "", fmt.Errorf("failed to parse start WAL location from pg_basebackup stderr")
|
||||
}
|
||||
|
||||
stopMatch := stopLSNRegex.FindStringSubmatch(stderr)
|
||||
if len(stopMatch) < 2 {
|
||||
return "", "", fmt.Errorf("failed to parse stop WAL location from pg_basebackup stderr")
|
||||
}
|
||||
|
||||
startSegment, err = LSNToSegmentName(startMatch[1], 1, defaultWalSegmentSize)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to convert start LSN to segment name: %w", err)
|
||||
}
|
||||
|
||||
stopSegment, err = LSNToSegmentName(stopMatch[1], 1, defaultWalSegmentSize)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to convert stop LSN to segment name: %w", err)
|
||||
}
|
||||
|
||||
return startSegment, stopSegment, nil
|
||||
}
|
||||
|
||||
func LSNToSegmentName(lsn string, timelineID, walSegmentSize uint32) (string, error) {
|
||||
high, low, err := parseLSN(lsn)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
segmentsPerXLogID := uint32(0x100000000 / uint64(walSegmentSize))
|
||||
logID := high
|
||||
segmentOffset := low / walSegmentSize
|
||||
|
||||
if segmentOffset >= segmentsPerXLogID {
|
||||
return "", fmt.Errorf("segment offset %d exceeds segments per XLogId %d", segmentOffset, segmentsPerXLogID)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%08X%08X%08X", timelineID, logID, segmentOffset), nil
|
||||
}
|
||||
|
||||
func parseLSN(lsn string) (high, low uint32, err error) {
|
||||
parts := strings.SplitN(lsn, "/", 2)
|
||||
if len(parts) != 2 {
|
||||
return 0, 0, fmt.Errorf("invalid LSN format: %q (expected X/Y)", lsn)
|
||||
}
|
||||
|
||||
highVal, err := strconv.ParseUint(parts[0], 16, 32)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("invalid LSN high part %q: %w", parts[0], err)
|
||||
}
|
||||
|
||||
lowVal, err := strconv.ParseUint(parts[1], 16, 32)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("invalid LSN low part %q: %w", parts[1], err)
|
||||
}
|
||||
|
||||
return uint32(highVal), uint32(lowVal), nil
|
||||
}
|
||||
162
agent/internal/features/full_backup/stderr_parser_test.go
Normal file
162
agent/internal/features/full_backup/stderr_parser_test.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package full_backup
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_ParseBasebackupStderr_WithPG17Output_ExtractsCorrectSegments(t *testing.T) {
|
||||
stderr := `pg_basebackup: initiating base backup, waiting for checkpoint to complete
|
||||
pg_basebackup: checkpoint completed
|
||||
pg_basebackup: write-ahead log start point: 0/2000028, on timeline 1
|
||||
pg_basebackup: starting background WAL receiver
|
||||
pg_basebackup: checkpoint redo point at 0/2000028
|
||||
pg_basebackup: write-ahead log end point: 0/2000100
|
||||
pg_basebackup: waiting for background process to finish streaming ...
|
||||
pg_basebackup: syncing data to disk ...
|
||||
pg_basebackup: renaming backup_manifest.tmp to backup_manifest
|
||||
pg_basebackup: base backup completed`
|
||||
|
||||
startSeg, stopSeg, err := ParseBasebackupStderr(stderr)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "000000010000000000000002", startSeg)
|
||||
assert.Equal(t, "000000010000000000000002", stopSeg)
|
||||
}
|
||||
|
||||
func Test_ParseBasebackupStderr_WithPG15Output_ExtractsCorrectSegments(t *testing.T) {
|
||||
stderr := `pg_basebackup: initiating base backup, waiting for checkpoint to complete
|
||||
pg_basebackup: checkpoint completed
|
||||
pg_basebackup: write-ahead log start point: 1/AB000028, on timeline 1
|
||||
pg_basebackup: checkpoint redo point at 1/AB000028
|
||||
pg_basebackup: write-ahead log end point: 1/AC000000
|
||||
pg_basebackup: base backup completed`
|
||||
|
||||
startSeg, stopSeg, err := ParseBasebackupStderr(stderr)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "0000000100000001000000AB", startSeg)
|
||||
assert.Equal(t, "0000000100000001000000AC", stopSeg)
|
||||
}
|
||||
|
||||
func Test_ParseBasebackupStderr_WithHighLogID_ExtractsCorrectSegments(t *testing.T) {
|
||||
stderr := `pg_basebackup: checkpoint redo point at A/FF000028
|
||||
pg_basebackup: write-ahead log end point: B/1000000`
|
||||
|
||||
startSeg, stopSeg, err := ParseBasebackupStderr(stderr)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "000000010000000A000000FF", startSeg)
|
||||
assert.Equal(t, "000000010000000B00000001", stopSeg)
|
||||
}
|
||||
|
||||
func Test_ParseBasebackupStderr_WhenStartLSNMissing_ReturnsError(t *testing.T) {
|
||||
stderr := `pg_basebackup: write-ahead log end point: 0/2000100
|
||||
pg_basebackup: base backup completed`
|
||||
|
||||
_, _, err := ParseBasebackupStderr(stderr)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to parse start WAL location")
|
||||
}
|
||||
|
||||
func Test_ParseBasebackupStderr_WhenStopLSNMissing_ReturnsError(t *testing.T) {
|
||||
stderr := `pg_basebackup: checkpoint redo point at 0/2000028
|
||||
pg_basebackup: base backup completed`
|
||||
|
||||
_, _, err := ParseBasebackupStderr(stderr)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to parse stop WAL location")
|
||||
}
|
||||
|
||||
func Test_ParseBasebackupStderr_WhenEmptyStderr_ReturnsError(t *testing.T) {
|
||||
_, _, err := ParseBasebackupStderr("")
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to parse start WAL location")
|
||||
}
|
||||
|
||||
func Test_LSNToSegmentName_WithBoundaryValues_ConvertsCorrectly(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
lsn string
|
||||
timeline uint32
|
||||
segSize uint32
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "first segment",
|
||||
lsn: "0/1000000",
|
||||
timeline: 1,
|
||||
segSize: 16 * 1024 * 1024,
|
||||
expected: "000000010000000000000001",
|
||||
},
|
||||
{
|
||||
name: "segment at boundary FF",
|
||||
lsn: "0/FF000000",
|
||||
timeline: 1,
|
||||
segSize: 16 * 1024 * 1024,
|
||||
expected: "0000000100000000000000FF",
|
||||
},
|
||||
{
|
||||
name: "segment in second log file",
|
||||
lsn: "1/0",
|
||||
timeline: 1,
|
||||
segSize: 16 * 1024 * 1024,
|
||||
expected: "000000010000000100000000",
|
||||
},
|
||||
{
|
||||
name: "segment with offset within 16MB",
|
||||
lsn: "0/200ABCD",
|
||||
timeline: 1,
|
||||
segSize: 16 * 1024 * 1024,
|
||||
expected: "000000010000000000000002",
|
||||
},
|
||||
{
|
||||
name: "zero LSN",
|
||||
lsn: "0/0",
|
||||
timeline: 1,
|
||||
segSize: 16 * 1024 * 1024,
|
||||
expected: "000000010000000000000000",
|
||||
},
|
||||
{
|
||||
name: "high timeline ID",
|
||||
lsn: "0/1000000",
|
||||
timeline: 2,
|
||||
segSize: 16 * 1024 * 1024,
|
||||
expected: "000000020000000000000001",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := LSNToSegmentName(tt.lsn, tt.timeline, tt.segSize)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_LSNToSegmentName_WithInvalidLSN_ReturnsError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
lsn string
|
||||
}{
|
||||
{name: "no slash", lsn: "012345"},
|
||||
{name: "empty string", lsn: ""},
|
||||
{name: "invalid hex high", lsn: "GG/0"},
|
||||
{name: "invalid hex low", lsn: "0/ZZ"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := LSNToSegmentName(tt.lsn, 1, 16*1024*1024)
|
||||
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
121
agent/internal/features/start/daemon.go
Normal file
121
agent/internal/features/start/daemon.go
Normal file
@@ -0,0 +1,121 @@
|
||||
//go:build !windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
logFileName = "databasus.log"
|
||||
stopTimeout = 30 * time.Second
|
||||
stopPollInterval = 500 * time.Millisecond
|
||||
daemonStartupDelay = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
func Stop(log *slog.Logger) error {
|
||||
pid, err := ReadLockFilePID()
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return errors.New("agent is not running (no lock file found)")
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to read lock file: %w", err)
|
||||
}
|
||||
|
||||
if !isProcessAlive(pid) {
|
||||
_ = os.Remove(lockFileName)
|
||||
return fmt.Errorf("agent is not running (stale lock file removed, PID %d)", pid)
|
||||
}
|
||||
|
||||
log.Info("Sending SIGTERM to agent", "pid", pid)
|
||||
|
||||
if err := syscall.Kill(pid, syscall.SIGTERM); err != nil {
|
||||
return fmt.Errorf("failed to send SIGTERM to PID %d: %w", pid, err)
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(stopTimeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if !isProcessAlive(pid) {
|
||||
log.Info("Agent stopped", "pid", pid)
|
||||
return nil
|
||||
}
|
||||
|
||||
time.Sleep(stopPollInterval)
|
||||
}
|
||||
|
||||
return fmt.Errorf("agent (PID %d) did not stop within %s — process may be stuck", pid, stopTimeout)
|
||||
}
|
||||
|
||||
func Status(log *slog.Logger) error {
|
||||
pid, err := ReadLockFilePID()
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
fmt.Println("Agent is not running")
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to read lock file: %w", err)
|
||||
}
|
||||
|
||||
if isProcessAlive(pid) {
|
||||
fmt.Printf("Agent is running (PID %d)\n", pid)
|
||||
} else {
|
||||
fmt.Println("Agent is not running (stale lock file)")
|
||||
_ = os.Remove(lockFileName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func spawnDaemon(log *slog.Logger) (int, error) {
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to resolve executable path: %w", err)
|
||||
}
|
||||
|
||||
args := []string{"_run"}
|
||||
|
||||
logFile, err := os.OpenFile(logFileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to open log file %s: %w", logFileName, err)
|
||||
}
|
||||
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
_ = logFile.Close()
|
||||
return 0, fmt.Errorf("failed to get working directory: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(context.Background(), execPath, args...)
|
||||
cmd.Dir = cwd
|
||||
cmd.Stderr = logFile
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
_ = logFile.Close()
|
||||
return 0, fmt.Errorf("failed to start daemon process: %w", err)
|
||||
}
|
||||
|
||||
pid := cmd.Process.Pid
|
||||
|
||||
// Detach — we don't wait for the child
|
||||
_ = logFile.Close()
|
||||
|
||||
time.Sleep(daemonStartupDelay)
|
||||
|
||||
if !isProcessAlive(pid) {
|
||||
return 0, fmt.Errorf("daemon process (PID %d) exited immediately — check %s for details", pid, logFileName)
|
||||
}
|
||||
|
||||
log.Info("Daemon spawned", "pid", pid, "log", logFileName)
|
||||
|
||||
return pid, nil
|
||||
}
|
||||
20
agent/internal/features/start/daemon_windows.go
Normal file
20
agent/internal/features/start/daemon_windows.go
Normal file
@@ -0,0 +1,20 @@
|
||||
//go:build windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
func Stop(log *slog.Logger) error {
|
||||
return errors.New("stop is not supported on Windows — use Ctrl+C in the terminal where the agent is running")
|
||||
}
|
||||
|
||||
func Status(log *slog.Logger) error {
|
||||
return errors.New("status is not supported on Windows — check the terminal where the agent is running")
|
||||
}
|
||||
|
||||
func spawnDaemon(_ *slog.Logger) (int, error) {
|
||||
return 0, errors.New("daemon mode is not supported on Windows")
|
||||
}
|
||||
117
agent/internal/features/start/lock.go
Normal file
117
agent/internal/features/start/lock.go
Normal file
@@ -0,0 +1,117 @@
|
||||
//go:build !windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
const lockFileName = "databasus.lock"
|
||||
|
||||
func AcquireLock(log *slog.Logger) (*os.File, error) {
|
||||
f, err := os.OpenFile(lockFileName, os.O_CREATE|os.O_RDWR, 0o644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open lock file: %w", err)
|
||||
}
|
||||
|
||||
err = syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
|
||||
if err == nil {
|
||||
if err := writePID(f); err != nil {
|
||||
_ = f.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Info("Process lock acquired", "pid", os.Getpid(), "lockFile", lockFileName)
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
if !errors.Is(err, syscall.EWOULDBLOCK) {
|
||||
_ = f.Close()
|
||||
return nil, fmt.Errorf("failed to acquire lock: %w", err)
|
||||
}
|
||||
|
||||
pid, pidErr := readLockPID(f)
|
||||
_ = f.Close()
|
||||
|
||||
if pidErr != nil {
|
||||
return nil, fmt.Errorf("another instance is already running")
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("another instance is already running (PID %d)", pid)
|
||||
}
|
||||
|
||||
func ReleaseLock(f *os.File) {
|
||||
_ = syscall.Flock(int(f.Fd()), syscall.LOCK_UN)
|
||||
_ = f.Close()
|
||||
_ = os.Remove(lockFileName)
|
||||
}
|
||||
|
||||
func ReadLockFilePID() (int, error) {
|
||||
f, err := os.Open(lockFileName)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
return readLockPID(f)
|
||||
}
|
||||
|
||||
func writePID(f *os.File) error {
|
||||
if err := f.Truncate(0); err != nil {
|
||||
return fmt.Errorf("failed to truncate lock file: %w", err)
|
||||
}
|
||||
|
||||
if _, err := f.Seek(0, io.SeekStart); err != nil {
|
||||
return fmt.Errorf("failed to seek lock file: %w", err)
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(f, "%d\n", os.Getpid()); err != nil {
|
||||
return fmt.Errorf("failed to write PID to lock file: %w", err)
|
||||
}
|
||||
|
||||
return f.Sync()
|
||||
}
|
||||
|
||||
func readLockPID(f *os.File) (int, error) {
|
||||
if _, err := f.Seek(0, io.SeekStart); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
s := strings.TrimSpace(string(data))
|
||||
if s == "" {
|
||||
return 0, errors.New("lock file is empty")
|
||||
}
|
||||
|
||||
pid, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid PID in lock file: %w", err)
|
||||
}
|
||||
|
||||
return pid, nil
|
||||
}
|
||||
|
||||
func isProcessAlive(pid int) bool {
|
||||
err := syscall.Kill(pid, 0)
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if errors.Is(err, syscall.EPERM) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
148
agent/internal/features/start/lock_test.go
Normal file
148
agent/internal/features/start/lock_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
//go:build !windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
func Test_AcquireLock_LockFileCreatedWithPID(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
data, err := os.ReadFile(lockFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, os.Getpid(), pid)
|
||||
}
|
||||
|
||||
func Test_AcquireLock_SecondAcquireFails_WhenFirstHeld(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
first, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(first)
|
||||
|
||||
second, err := AcquireLock(log)
|
||||
assert.Nil(t, second)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "another instance is already running")
|
||||
assert.Contains(t, err.Error(), fmt.Sprintf("PID %d", os.Getpid()))
|
||||
}
|
||||
|
||||
func Test_AcquireLock_StaleLockReacquired_WhenProcessDead(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
err := os.WriteFile(lockFileName, []byte("999999999\n"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
data, err := os.ReadFile(lockFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, os.Getpid(), pid)
|
||||
}
|
||||
|
||||
func Test_ReleaseLock_LockFileRemoved(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
|
||||
ReleaseLock(lockFile)
|
||||
|
||||
_, err = os.Stat(lockFileName)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
}
|
||||
|
||||
func Test_AcquireLock_ReacquiredAfterRelease(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
first, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
ReleaseLock(first)
|
||||
|
||||
second, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(second)
|
||||
|
||||
data, err := os.ReadFile(lockFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, os.Getpid(), pid)
|
||||
}
|
||||
|
||||
func Test_isProcessAlive_ReturnsTrueForSelf(t *testing.T) {
|
||||
assert.True(t, isProcessAlive(os.Getpid()))
|
||||
}
|
||||
|
||||
func Test_isProcessAlive_ReturnsFalseForNonExistentPID(t *testing.T) {
|
||||
assert.False(t, isProcessAlive(999999999))
|
||||
}
|
||||
|
||||
func Test_readLockPID_ParsesValidPID(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
f, err := os.CreateTemp("", "lock-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(f.Name())
|
||||
|
||||
_, err = f.WriteString("12345\n")
|
||||
require.NoError(t, err)
|
||||
|
||||
pid, err := readLockPID(f)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 12345, pid)
|
||||
}
|
||||
|
||||
func Test_readLockPID_ReturnsErrorForEmptyFile(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
f, err := os.CreateTemp("", "lock-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(f.Name())
|
||||
|
||||
_, err = readLockPID(f)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "lock file is empty")
|
||||
}
|
||||
|
||||
func setupTempDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
origDir, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
|
||||
dir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(dir))
|
||||
|
||||
t.Cleanup(func() { _ = os.Chdir(origDir) })
|
||||
|
||||
return dir
|
||||
}
|
||||
90
agent/internal/features/start/lock_watcher.go
Normal file
90
agent/internal/features/start/lock_watcher.go
Normal file
@@ -0,0 +1,90 @@
|
||||
//go:build !windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
const lockWatchInterval = 5 * time.Second
|
||||
|
||||
type LockWatcher struct {
|
||||
originalInode uint64
|
||||
cancel context.CancelFunc
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
func NewLockWatcher(lockFile *os.File, cancel context.CancelFunc, log *slog.Logger) (*LockWatcher, error) {
|
||||
inode, err := getFileInode(lockFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &LockWatcher{
|
||||
originalInode: inode,
|
||||
cancel: cancel,
|
||||
log: log,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *LockWatcher) Run(ctx context.Context) {
|
||||
ticker := time.NewTicker(lockWatchInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
w.check()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *LockWatcher) check() {
|
||||
info, err := os.Stat(lockFileName)
|
||||
if err != nil {
|
||||
w.log.Error("Lock file disappeared, shutting down", "file", lockFileName, "error", err)
|
||||
w.cancel()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
currentInode, err := getStatInode(info)
|
||||
if err != nil {
|
||||
w.log.Error("Failed to read lock file inode, shutting down", "error", err)
|
||||
w.cancel()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if currentInode != w.originalInode {
|
||||
w.log.Error("Lock file was replaced (inode changed), shutting down",
|
||||
"originalInode", w.originalInode,
|
||||
"currentInode", currentInode,
|
||||
)
|
||||
w.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
func getFileInode(f *os.File) (uint64, error) {
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return getStatInode(info)
|
||||
}
|
||||
|
||||
func getStatInode(info os.FileInfo) (uint64, error) {
|
||||
stat, ok := info.Sys().(*syscall.Stat_t)
|
||||
if !ok {
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
|
||||
return stat.Ino, nil
|
||||
}
|
||||
110
agent/internal/features/start/lock_watcher_test.go
Normal file
110
agent/internal/features/start/lock_watcher_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
//go:build !windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
func Test_NewLockWatcher_CapturesInode(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, watcher.originalInode)
|
||||
}
|
||||
|
||||
func Test_LockWatcher_FileUnchanged_ContextNotCancelled(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
watcher.check()
|
||||
watcher.check()
|
||||
watcher.check()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("context should not be cancelled when lock file is unchanged")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func Test_LockWatcher_FileDeleted_CancelsContext(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.Remove(lockFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
watcher.check()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
default:
|
||||
t.Fatal("context should be cancelled when lock file is deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_LockWatcher_FileReplacedWithDifferentInode_CancelsContext(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.Remove(lockFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(lockFileName, []byte("99999\n"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
watcher.check()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
default:
|
||||
t.Fatal("context should be cancelled when lock file inode changes")
|
||||
}
|
||||
}
|
||||
17
agent/internal/features/start/lock_watcher_windows.go
Normal file
17
agent/internal/features/start/lock_watcher_windows.go
Normal file
@@ -0,0 +1,17 @@
|
||||
//go:build windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
)
|
||||
|
||||
type LockWatcher struct{}
|
||||
|
||||
func NewLockWatcher(_ *os.File, _ context.CancelFunc, _ *slog.Logger) (*LockWatcher, error) {
|
||||
return &LockWatcher{}, nil
|
||||
}
|
||||
|
||||
func (w *LockWatcher) Run(_ context.Context) {}
|
||||
18
agent/internal/features/start/lock_windows.go
Normal file
18
agent/internal/features/start/lock_windows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package start
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
)
|
||||
|
||||
func AcquireLock(log *slog.Logger) (*os.File, error) {
|
||||
log.Warn("Process locking is not supported on Windows, skipping")
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func ReleaseLock(f *os.File) {
|
||||
if f != nil {
|
||||
_ = f.Close()
|
||||
}
|
||||
}
|
||||
271
agent/internal/features/start/start.go
Normal file
271
agent/internal/features/start/start.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package start
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
full_backup "databasus-agent/internal/features/full_backup"
|
||||
"databasus-agent/internal/features/upgrade"
|
||||
"databasus-agent/internal/features/wal"
|
||||
)
|
||||
|
||||
const (
|
||||
pgBasebackupVerifyTimeout = 10 * time.Second
|
||||
dbVerifyTimeout = 10 * time.Second
|
||||
minPgMajorVersion = 15
|
||||
)
|
||||
|
||||
func Start(cfg *config.Config, agentVersion string, isDev bool, log *slog.Logger) error {
|
||||
if err := validateConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := verifyPgBasebackup(cfg, log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := verifyDatabase(cfg, log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
return RunDaemon(cfg, agentVersion, isDev, log)
|
||||
}
|
||||
|
||||
pid, err := spawnDaemon(log)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Agent started in background (PID %d)\n", pid)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunDaemon(cfg *config.Config, agentVersion string, isDev bool, log *slog.Logger) error {
|
||||
lockFile, err := AcquireLock(log)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize lock watcher: %w", err)
|
||||
}
|
||||
go watcher.Run(ctx)
|
||||
|
||||
apiClient := api.NewClient(cfg.DatabasusHost, cfg.Token, log)
|
||||
|
||||
var backgroundUpgrader *upgrade.BackgroundUpgrader
|
||||
if agentVersion != "dev" && runtime.GOOS != "windows" {
|
||||
backgroundUpgrader = upgrade.NewBackgroundUpgrader(apiClient, agentVersion, isDev, cancel, log)
|
||||
go backgroundUpgrader.Run(ctx)
|
||||
}
|
||||
|
||||
fullBackuper := full_backup.NewFullBackuper(cfg, apiClient, log)
|
||||
go fullBackuper.Run(ctx)
|
||||
|
||||
streamer := wal.NewStreamer(cfg, apiClient, log)
|
||||
streamer.Run(ctx)
|
||||
|
||||
if backgroundUpgrader != nil {
|
||||
backgroundUpgrader.WaitForCompletion(30 * time.Second)
|
||||
|
||||
if backgroundUpgrader.IsUpgraded() {
|
||||
return upgrade.ErrUpgradeRestart
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("Agent stopped")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateConfig(cfg *config.Config) error {
|
||||
if cfg.DatabasusHost == "" {
|
||||
return errors.New("argument databasus-host is required")
|
||||
}
|
||||
|
||||
if cfg.DbID == "" {
|
||||
return errors.New("argument db-id is required")
|
||||
}
|
||||
|
||||
if cfg.Token == "" {
|
||||
return errors.New("argument token is required")
|
||||
}
|
||||
|
||||
if cfg.PgHost == "" {
|
||||
return errors.New("argument pg-host is required")
|
||||
}
|
||||
|
||||
if cfg.PgPort <= 0 {
|
||||
return errors.New("argument pg-port must be a positive number")
|
||||
}
|
||||
|
||||
if cfg.PgUser == "" {
|
||||
return errors.New("argument pg-user is required")
|
||||
}
|
||||
|
||||
if cfg.PgType != "host" && cfg.PgType != "docker" {
|
||||
return fmt.Errorf("argument pg-type must be 'host' or 'docker', got '%s'", cfg.PgType)
|
||||
}
|
||||
|
||||
if cfg.PgWalDir == "" {
|
||||
return errors.New("argument pg-wal-dir is required")
|
||||
}
|
||||
|
||||
if cfg.PgType == "docker" && cfg.PgDockerContainerName == "" {
|
||||
return errors.New("argument pg-docker-container-name is required when pg-type is 'docker'")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyPgBasebackup(cfg *config.Config, log *slog.Logger) error {
|
||||
switch cfg.PgType {
|
||||
case "host":
|
||||
return verifyPgBasebackupHost(cfg, log)
|
||||
case "docker":
|
||||
return verifyPgBasebackupDocker(cfg, log)
|
||||
default:
|
||||
return fmt.Errorf("unexpected pg-type: %s", cfg.PgType)
|
||||
}
|
||||
}
|
||||
|
||||
func verifyPgBasebackupHost(cfg *config.Config, log *slog.Logger) error {
|
||||
binary := "pg_basebackup"
|
||||
if cfg.PgHostBinDir != "" {
|
||||
binary = filepath.Join(cfg.PgHostBinDir, "pg_basebackup")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgBasebackupVerifyTimeout)
|
||||
defer cancel()
|
||||
|
||||
output, err := exec.CommandContext(ctx, binary, "--version").CombinedOutput()
|
||||
if err != nil {
|
||||
if cfg.PgHostBinDir != "" {
|
||||
return fmt.Errorf(
|
||||
"pg_basebackup not found at '%s': %w. Verify pg-host-bin-dir is correct",
|
||||
binary, err,
|
||||
)
|
||||
}
|
||||
|
||||
return fmt.Errorf(
|
||||
"pg_basebackup not found in PATH: %w. Install PostgreSQL client tools or set pg-host-bin-dir",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
log.Info("pg_basebackup verified", "version", strings.TrimSpace(string(output)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyPgBasebackupDocker(cfg *config.Config, log *slog.Logger) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgBasebackupVerifyTimeout)
|
||||
defer cancel()
|
||||
|
||||
output, err := exec.CommandContext(ctx,
|
||||
"docker", "exec", cfg.PgDockerContainerName,
|
||||
"pg_basebackup", "--version",
|
||||
).CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"pg_basebackup not available in container '%s': %w. "+
|
||||
"Check that the container is running and pg_basebackup is installed inside it",
|
||||
cfg.PgDockerContainerName, err,
|
||||
)
|
||||
}
|
||||
|
||||
log.Info("pg_basebackup verified (docker)",
|
||||
"container", cfg.PgDockerContainerName,
|
||||
"version", strings.TrimSpace(string(output)),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyDatabase(cfg *config.Config, log *slog.Logger) error {
|
||||
connStr := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=postgres sslmode=disable",
|
||||
cfg.PgHost, cfg.PgPort, cfg.PgUser, cfg.PgPassword,
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dbVerifyTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgx.Connect(ctx, connStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to connect to PostgreSQL at %s:%d as user '%s': %w",
|
||||
cfg.PgHost, cfg.PgPort, cfg.PgUser, err,
|
||||
)
|
||||
}
|
||||
defer func() { _ = conn.Close(ctx) }()
|
||||
|
||||
if err := conn.Ping(ctx); err != nil {
|
||||
return fmt.Errorf("PostgreSQL ping failed at %s:%d: %w",
|
||||
cfg.PgHost, cfg.PgPort, err,
|
||||
)
|
||||
}
|
||||
|
||||
var versionNumStr string
|
||||
if err := conn.QueryRow(ctx, "SHOW server_version_num").Scan(&versionNumStr); err != nil {
|
||||
return fmt.Errorf("failed to query PostgreSQL version: %w", err)
|
||||
}
|
||||
|
||||
majorVersion, err := parsePgVersionNum(versionNumStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse PostgreSQL version '%s': %w", versionNumStr, err)
|
||||
}
|
||||
|
||||
if majorVersion < minPgMajorVersion {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL %d is not supported, minimum required version is %d",
|
||||
majorVersion, minPgMajorVersion,
|
||||
)
|
||||
}
|
||||
|
||||
log.Info("PostgreSQL connection verified",
|
||||
"host", cfg.PgHost,
|
||||
"port", cfg.PgPort,
|
||||
"user", cfg.PgUser,
|
||||
"version", majorVersion,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parsePgVersionNum(versionNumStr string) (int, error) {
|
||||
versionNum, err := strconv.Atoi(strings.TrimSpace(versionNumStr))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid version number: %w", err)
|
||||
}
|
||||
|
||||
if versionNum <= 0 {
|
||||
return 0, fmt.Errorf("invalid version number: %d", versionNum)
|
||||
}
|
||||
|
||||
majorVersion := versionNum / 10000
|
||||
|
||||
return majorVersion, nil
|
||||
}
|
||||
84
agent/internal/features/start/start_test.go
Normal file
84
agent/internal/features/start/start_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package start
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_ParsePgVersionNum_SupportedVersions_ReturnsMajorVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
versionNumStr string
|
||||
expectedMajor int
|
||||
}{
|
||||
{name: "PG 15.0", versionNumStr: "150000", expectedMajor: 15},
|
||||
{name: "PG 15.4", versionNumStr: "150004", expectedMajor: 15},
|
||||
{name: "PG 16.0", versionNumStr: "160000", expectedMajor: 16},
|
||||
{name: "PG 16.3", versionNumStr: "160003", expectedMajor: 16},
|
||||
{name: "PG 17.2", versionNumStr: "170002", expectedMajor: 17},
|
||||
{name: "PG 18.0", versionNumStr: "180000", expectedMajor: 18},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
major, err := parsePgVersionNum(tt.versionNumStr)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedMajor, major)
|
||||
assert.GreaterOrEqual(t, major, minPgMajorVersion)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ParsePgVersionNum_UnsupportedVersions_ReturnsMajorVersionBelow15(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
versionNumStr string
|
||||
expectedMajor int
|
||||
}{
|
||||
{name: "PG 12.5", versionNumStr: "120005", expectedMajor: 12},
|
||||
{name: "PG 13.0", versionNumStr: "130000", expectedMajor: 13},
|
||||
{name: "PG 14.12", versionNumStr: "140012", expectedMajor: 14},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
major, err := parsePgVersionNum(tt.versionNumStr)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedMajor, major)
|
||||
assert.Less(t, major, minPgMajorVersion)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ParsePgVersionNum_InvalidInput_ReturnsError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
versionNumStr string
|
||||
}{
|
||||
{name: "empty string", versionNumStr: ""},
|
||||
{name: "non-numeric", versionNumStr: "abc"},
|
||||
{name: "negative number", versionNumStr: "-1"},
|
||||
{name: "zero", versionNumStr: "0"},
|
||||
{name: "float", versionNumStr: "15.4"},
|
||||
{name: "whitespace only", versionNumStr: " "},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := parsePgVersionNum(tt.versionNumStr)
|
||||
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ParsePgVersionNum_WithWhitespace_ParsesCorrectly(t *testing.T) {
|
||||
major, err := parsePgVersionNum(" 150004 ")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 15, major)
|
||||
}
|
||||
88
agent/internal/features/upgrade/background_upgrader.go
Normal file
88
agent/internal/features/upgrade/background_upgrader.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package upgrade
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"databasus-agent/internal/features/api"
|
||||
)
|
||||
|
||||
const backgroundCheckInterval = 5 * time.Second
|
||||
|
||||
type BackgroundUpgrader struct {
|
||||
apiClient *api.Client
|
||||
currentVersion string
|
||||
isDev bool
|
||||
cancel context.CancelFunc
|
||||
isUpgraded atomic.Bool
|
||||
log *slog.Logger
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func NewBackgroundUpgrader(
|
||||
apiClient *api.Client,
|
||||
currentVersion string,
|
||||
isDev bool,
|
||||
cancel context.CancelFunc,
|
||||
log *slog.Logger,
|
||||
) *BackgroundUpgrader {
|
||||
return &BackgroundUpgrader{
|
||||
apiClient,
|
||||
currentVersion,
|
||||
isDev,
|
||||
cancel,
|
||||
atomic.Bool{},
|
||||
log,
|
||||
make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (u *BackgroundUpgrader) Run(ctx context.Context) {
|
||||
defer close(u.done)
|
||||
|
||||
ticker := time.NewTicker(backgroundCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if u.checkAndUpgrade() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *BackgroundUpgrader) IsUpgraded() bool {
|
||||
return u.isUpgraded.Load()
|
||||
}
|
||||
|
||||
func (u *BackgroundUpgrader) WaitForCompletion(timeout time.Duration) {
|
||||
select {
|
||||
case <-u.done:
|
||||
case <-time.After(timeout):
|
||||
}
|
||||
}
|
||||
|
||||
func (u *BackgroundUpgrader) checkAndUpgrade() bool {
|
||||
isUpgraded, err := CheckAndUpdate(u.apiClient, u.currentVersion, u.isDev, u.log)
|
||||
if err != nil {
|
||||
u.log.Warn("Background update check failed", "error", err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
if !isUpgraded {
|
||||
return false
|
||||
}
|
||||
|
||||
u.log.Info("Background upgrade complete, restarting...")
|
||||
u.isUpgraded.Store(true)
|
||||
u.cancel()
|
||||
|
||||
return true
|
||||
}
|
||||
5
agent/internal/features/upgrade/errors.go
Normal file
5
agent/internal/features/upgrade/errors.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package upgrade
|
||||
|
||||
import "errors"
|
||||
|
||||
var ErrUpgradeRestart = errors.New("agent upgraded, restart required")
|
||||
89
agent/internal/features/upgrade/upgrader.go
Normal file
89
agent/internal/features/upgrade/upgrader.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package upgrade
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"databasus-agent/internal/features/api"
|
||||
)
|
||||
|
||||
// CheckAndUpdate checks if a new version is available and upgrades the binary on disk.
|
||||
// Returns (true, nil) if the binary was upgraded, (false, nil) if already up to date,
|
||||
// or (false, err) on failure. Callers are responsible for re-exec or restart signaling.
|
||||
func CheckAndUpdate(apiClient *api.Client, currentVersion string, isDev bool, log *slog.Logger) (bool, error) {
|
||||
if isDev {
|
||||
log.Info("Skipping update check (development mode)")
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
serverVersion, err := apiClient.FetchServerVersion(context.Background())
|
||||
if err != nil {
|
||||
log.Warn("Could not reach server for update check", "error", err)
|
||||
|
||||
return false, fmt.Errorf(
|
||||
"unable to check version, please verify Databasus server is available: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
if serverVersion == currentVersion {
|
||||
log.Info("Agent version is up to date", "version", currentVersion)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
log.Info("Updating agent...", "current", currentVersion, "target", serverVersion)
|
||||
|
||||
selfPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to determine executable path: %w", err)
|
||||
}
|
||||
|
||||
tempPath := selfPath + ".update"
|
||||
|
||||
defer func() {
|
||||
_ = os.Remove(tempPath)
|
||||
}()
|
||||
|
||||
if err := apiClient.DownloadAgentBinary(context.Background(), runtime.GOARCH, tempPath); err != nil {
|
||||
return false, fmt.Errorf("failed to download update: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tempPath, 0o755); err != nil {
|
||||
return false, fmt.Errorf("failed to set permissions on update: %w", err)
|
||||
}
|
||||
|
||||
if err := verifyBinary(tempPath, serverVersion); err != nil {
|
||||
return false, fmt.Errorf("update verification failed: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tempPath, selfPath); err != nil {
|
||||
return false, fmt.Errorf("failed to replace binary (try --skip-update if this persists): %w", err)
|
||||
}
|
||||
|
||||
log.Info("Agent binary updated", "version", serverVersion)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func verifyBinary(binaryPath, expectedVersion string) error {
|
||||
cmd := exec.CommandContext(context.Background(), binaryPath, "version")
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("binary failed to execute: %w", err)
|
||||
}
|
||||
|
||||
got := strings.TrimSpace(string(output))
|
||||
if got != expectedVersion {
|
||||
return fmt.Errorf("version mismatch: expected %q, got %q", expectedVersion, got)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
182
agent/internal/features/wal/streamer.go
Normal file
182
agent/internal/features/wal/streamer.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
)
|
||||
|
||||
const (
|
||||
pollInterval = 2 * time.Second
|
||||
uploadTimeout = 5 * time.Minute
|
||||
)
|
||||
|
||||
var segmentNameRegex = regexp.MustCompile(`^[0-9A-Fa-f]{24}$`)
|
||||
|
||||
type Streamer struct {
|
||||
cfg *config.Config
|
||||
apiClient *api.Client
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
func NewStreamer(cfg *config.Config, apiClient *api.Client, log *slog.Logger) *Streamer {
|
||||
return &Streamer{
|
||||
cfg: cfg,
|
||||
apiClient: apiClient,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Streamer) Run(ctx context.Context) {
|
||||
s.log.Info("WAL streamer started", "pgWalDir", s.cfg.PgWalDir)
|
||||
|
||||
s.processQueue(ctx)
|
||||
|
||||
ticker := time.NewTicker(pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.log.Info("WAL streamer stopping")
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.processQueue(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Streamer) processQueue(ctx context.Context) {
|
||||
segments, err := s.listSegments()
|
||||
if err != nil {
|
||||
s.log.Error("Failed to list WAL segments", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, segmentName := range segments {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.uploadSegment(ctx, segmentName); err != nil {
|
||||
s.log.Error("Failed to upload WAL segment",
|
||||
"segment", segmentName,
|
||||
"error", err,
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Streamer) listSegments() ([]string, error) {
|
||||
entries, err := os.ReadDir(s.cfg.PgWalDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read wal dir: %w", err)
|
||||
}
|
||||
|
||||
var segments []string
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
name := entry.Name()
|
||||
|
||||
if strings.HasSuffix(name, ".tmp") {
|
||||
continue
|
||||
}
|
||||
|
||||
if !segmentNameRegex.MatchString(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
segments = append(segments, name)
|
||||
}
|
||||
|
||||
sort.Strings(segments)
|
||||
|
||||
return segments, nil
|
||||
}
|
||||
|
||||
func (s *Streamer) uploadSegment(ctx context.Context, segmentName string) error {
|
||||
filePath := filepath.Join(s.cfg.PgWalDir, segmentName)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
go s.compressAndStream(pw, filePath)
|
||||
|
||||
uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout)
|
||||
defer cancel()
|
||||
|
||||
result, err := s.apiClient.UploadWalSegment(uploadCtx, segmentName, pr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if result.IsGapDetected {
|
||||
s.log.Warn("WAL chain gap detected",
|
||||
"segment", segmentName,
|
||||
"expected", result.ExpectedSegmentName,
|
||||
"received", result.ReceivedSegmentName,
|
||||
)
|
||||
|
||||
return fmt.Errorf("gap detected for segment %s", segmentName)
|
||||
}
|
||||
|
||||
s.log.Debug("WAL segment uploaded", "segment", segmentName)
|
||||
|
||||
if *s.cfg.IsDeleteWalAfterUpload {
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
s.log.Warn("Failed to delete uploaded WAL segment",
|
||||
"segment", segmentName,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Streamer) compressAndStream(pw *io.PipeWriter, filePath string) {
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
_ = pw.CloseWithError(fmt.Errorf("open file: %w", err))
|
||||
return
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
encoder, err := zstd.NewWriter(pw,
|
||||
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(5)),
|
||||
zstd.WithEncoderCRC(true),
|
||||
)
|
||||
if err != nil {
|
||||
_ = pw.CloseWithError(fmt.Errorf("create zstd encoder: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := io.Copy(encoder, f); err != nil {
|
||||
_ = encoder.Close()
|
||||
_ = pw.CloseWithError(fmt.Errorf("compress: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := encoder.Close(); err != nil {
|
||||
_ = pw.CloseWithError(fmt.Errorf("close encoder: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
_ = pw.Close()
|
||||
}
|
||||
348
agent/internal/features/wal/streamer_test.go
Normal file
348
agent/internal/features/wal/streamer_test.go
Normal file
@@ -0,0 +1,348 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
func Test_UploadSegment_SingleSegment_ServerReceivesCorrectHeadersAndBody(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
segmentContent := []byte("test-wal-segment-data-for-upload")
|
||||
writeTestSegment(t, walDir, "000000010000000100000001", segmentContent)
|
||||
|
||||
var receivedHeaders http.Header
|
||||
var receivedBody []byte
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedHeaders = r.Header.Clone()
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
receivedBody = body
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
require.NotNil(t, receivedHeaders)
|
||||
assert.Equal(t, "test-token", receivedHeaders.Get("Authorization"))
|
||||
assert.Equal(t, "application/octet-stream", receivedHeaders.Get("Content-Type"))
|
||||
assert.Equal(t, "000000010000000100000001", receivedHeaders.Get("X-Wal-Segment-Name"))
|
||||
|
||||
decompressed := decompressZstd(t, receivedBody)
|
||||
assert.Equal(t, segmentContent, decompressed)
|
||||
}
|
||||
|
||||
func Test_UploadSegments_MultipleSegmentsOutOfOrder_UploadedInAscendingOrder(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
writeTestSegment(t, walDir, "000000010000000100000003", []byte("third"))
|
||||
writeTestSegment(t, walDir, "000000010000000100000001", []byte("first"))
|
||||
writeTestSegment(t, walDir, "000000010000000100000002", []byte("second"))
|
||||
|
||||
var mu sync.Mutex
|
||||
var uploadOrder []string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
uploadOrder = append(uploadOrder, r.Header.Get("X-Wal-Segment-Name"))
|
||||
mu.Unlock()
|
||||
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
require.Len(t, uploadOrder, 3)
|
||||
assert.Equal(t, "000000010000000100000001", uploadOrder[0])
|
||||
assert.Equal(t, "000000010000000100000002", uploadOrder[1])
|
||||
assert.Equal(t, "000000010000000100000003", uploadOrder[2])
|
||||
}
|
||||
|
||||
func Test_UploadSegments_DirectoryHasTmpFiles_TmpFilesIgnored(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
writeTestSegment(t, walDir, "000000010000000100000001", []byte("real segment"))
|
||||
writeTestSegment(t, walDir, "000000010000000100000002.tmp", []byte("partial copy"))
|
||||
|
||||
var mu sync.Mutex
|
||||
var uploadedSegments []string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
uploadedSegments = append(uploadedSegments, r.Header.Get("X-Wal-Segment-Name"))
|
||||
mu.Unlock()
|
||||
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
require.Len(t, uploadedSegments, 1)
|
||||
assert.Equal(t, "000000010000000100000001", uploadedSegments[0])
|
||||
}
|
||||
|
||||
func Test_UploadSegment_DeleteEnabled_FileRemovedAfterUpload(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
segmentName := "000000010000000100000001"
|
||||
writeTestSegment(t, walDir, segmentName, []byte("segment data"))
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
isDeleteEnabled := true
|
||||
cfg := createTestConfig(walDir, server.URL)
|
||||
cfg.IsDeleteWalAfterUpload = &isDeleteEnabled
|
||||
apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger())
|
||||
streamer := NewStreamer(cfg, apiClient, logger.GetLogger())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
_, err := os.Stat(filepath.Join(walDir, segmentName))
|
||||
assert.True(t, os.IsNotExist(err), "segment file should be deleted after successful upload")
|
||||
}
|
||||
|
||||
func Test_UploadSegment_DeleteDisabled_FileKeptAfterUpload(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
segmentName := "000000010000000100000001"
|
||||
writeTestSegment(t, walDir, segmentName, []byte("segment data"))
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
isDeleteDisabled := false
|
||||
cfg := createTestConfig(walDir, server.URL)
|
||||
cfg.IsDeleteWalAfterUpload = &isDeleteDisabled
|
||||
apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger())
|
||||
streamer := NewStreamer(cfg, apiClient, logger.GetLogger())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
_, err := os.Stat(filepath.Join(walDir, segmentName))
|
||||
assert.NoError(t, err, "segment file should be kept when delete is disabled")
|
||||
}
|
||||
|
||||
func Test_UploadSegment_ServerReturns500_FileKeptInQueue(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
segmentName := "000000010000000100000001"
|
||||
writeTestSegment(t, walDir, segmentName, []byte("segment data"))
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(`{"error":"internal server error"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
_, err := os.Stat(filepath.Join(walDir, segmentName))
|
||||
assert.NoError(t, err, "segment file should remain in queue after server error")
|
||||
}
|
||||
|
||||
func Test_ProcessQueue_EmptyDirectory_NoUploads(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
|
||||
uploadCount := 0
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
uploadCount++
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
assert.Equal(t, 0, uploadCount, "no uploads should occur for empty directory")
|
||||
}
|
||||
|
||||
func Test_Run_ContextCancelled_StopsImmediately(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
|
||||
streamer := newTestStreamer(walDir, "http://localhost:0")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
streamer.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Run should have stopped immediately when context is already cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_UploadSegment_ServerReturns409_FileNotDeleted(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
segmentName := "000000010000000100000005"
|
||||
writeTestSegment(t, walDir, segmentName, []byte("gap segment"))
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
|
||||
resp := map[string]string{
|
||||
"error": "gap_detected",
|
||||
"expectedSegmentName": "000000010000000100000003",
|
||||
"receivedSegmentName": segmentName,
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
_, err := os.Stat(filepath.Join(walDir, segmentName))
|
||||
assert.NoError(t, err, "segment file should not be deleted on gap detection")
|
||||
}
|
||||
|
||||
func newTestStreamer(walDir, serverURL string) *Streamer {
|
||||
cfg := createTestConfig(walDir, serverURL)
|
||||
apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger())
|
||||
|
||||
return NewStreamer(cfg, apiClient, logger.GetLogger())
|
||||
}
|
||||
|
||||
func createTestWalDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
baseDir := filepath.Join(".", ".test-tmp")
|
||||
if err := os.MkdirAll(baseDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create base test dir: %v", err)
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp(baseDir, t.Name()+"-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test wal dir: %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(dir)
|
||||
})
|
||||
|
||||
return dir
|
||||
}
|
||||
|
||||
func writeTestSegment(t *testing.T, dir, name string, content []byte) {
|
||||
t.Helper()
|
||||
|
||||
if err := os.WriteFile(filepath.Join(dir, name), content, 0o644); err != nil {
|
||||
t.Fatalf("failed to write test segment %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
func createTestConfig(walDir, serverURL string) *config.Config {
|
||||
isDeleteEnabled := true
|
||||
|
||||
return &config.Config{
|
||||
DatabasusHost: serverURL,
|
||||
DbID: "test-db-id",
|
||||
Token: "test-token",
|
||||
PgWalDir: walDir,
|
||||
IsDeleteWalAfterUpload: &isDeleteEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func decompressZstd(t *testing.T, data []byte) []byte {
|
||||
t.Helper()
|
||||
|
||||
decoder, err := zstd.NewReader(nil)
|
||||
require.NoError(t, err)
|
||||
defer decoder.Close()
|
||||
|
||||
decoded, err := decoder.DecodeAll(data, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
return decoded
|
||||
}
|
||||
119
agent/internal/logger/logger.go
Normal file
119
agent/internal/logger/logger.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
logFileName = "databasus.log"
|
||||
oldLogFileName = "databasus.log.old"
|
||||
maxLogFileSize = 5 * 1024 * 1024 // 5MB
|
||||
)
|
||||
|
||||
type rotatingWriter struct {
|
||||
mu sync.Mutex
|
||||
file *os.File
|
||||
currentSize int64
|
||||
maxSize int64
|
||||
logPath string
|
||||
oldLogPath string
|
||||
}
|
||||
|
||||
func (w *rotatingWriter) Write(p []byte) (int, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.currentSize+int64(len(p)) > w.maxSize {
|
||||
if err := w.rotate(); err != nil {
|
||||
return 0, fmt.Errorf("failed to rotate log file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
n, err := w.file.Write(p)
|
||||
w.currentSize += int64(n)
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *rotatingWriter) rotate() error {
|
||||
if err := w.file.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close %s: %w", w.logPath, err)
|
||||
}
|
||||
|
||||
if err := os.Remove(w.oldLogPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove %s: %w", w.oldLogPath, err)
|
||||
}
|
||||
|
||||
if err := os.Rename(w.logPath, w.oldLogPath); err != nil {
|
||||
return fmt.Errorf("failed to rename %s to %s: %w", w.logPath, w.oldLogPath, err)
|
||||
}
|
||||
|
||||
f, err := os.OpenFile(w.logPath, os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new %s: %w", w.logPath, err)
|
||||
}
|
||||
|
||||
w.file = f
|
||||
w.currentSize = 0
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
loggerInstance *slog.Logger
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func GetLogger() *slog.Logger {
|
||||
once.Do(func() {
|
||||
initialize()
|
||||
})
|
||||
|
||||
return loggerInstance
|
||||
}
|
||||
|
||||
func initialize() {
|
||||
writer := buildWriter()
|
||||
|
||||
loggerInstance = slog.New(slog.NewTextHandler(writer, &slog.HandlerOptions{
|
||||
Level: slog.LevelInfo,
|
||||
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
|
||||
if a.Key == slog.TimeKey {
|
||||
a.Value = slog.StringValue(time.Now().Format("2006/01/02 15:04:05"))
|
||||
}
|
||||
if a.Key == slog.LevelKey {
|
||||
return slog.Attr{}
|
||||
}
|
||||
|
||||
return a
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
func buildWriter() io.Writer {
|
||||
f, err := os.OpenFile(logFileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: failed to open %s for logging: %v\n", logFileName, err)
|
||||
return os.Stdout
|
||||
}
|
||||
|
||||
var currentSize int64
|
||||
if info, err := f.Stat(); err == nil {
|
||||
currentSize = info.Size()
|
||||
}
|
||||
|
||||
rw := &rotatingWriter{
|
||||
file: f,
|
||||
currentSize: currentSize,
|
||||
maxSize: maxLogFileSize,
|
||||
logPath: logFileName,
|
||||
oldLogPath: oldLogFileName,
|
||||
}
|
||||
|
||||
return io.MultiWriter(os.Stdout, rw)
|
||||
}
|
||||
128
agent/internal/logger/logger_test.go
Normal file
128
agent/internal/logger/logger_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Write_DataWrittenToFile(t *testing.T) {
|
||||
rw, logPath, _ := setupRotatingWriter(t, 1024)
|
||||
|
||||
data := []byte("hello world\n")
|
||||
n, err := rw.Write(data)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(data), n)
|
||||
assert.Equal(t, int64(len(data)), rw.currentSize)
|
||||
|
||||
content, err := os.ReadFile(logPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(data), string(content))
|
||||
}
|
||||
|
||||
func Test_Write_WhenLimitExceeded_FileRotated(t *testing.T) {
|
||||
rw, logPath, oldLogPath := setupRotatingWriter(t, 100)
|
||||
|
||||
firstData := []byte(strings.Repeat("A", 80))
|
||||
_, err := rw.Write(firstData)
|
||||
require.NoError(t, err)
|
||||
|
||||
secondData := []byte(strings.Repeat("B", 30))
|
||||
_, err = rw.Write(secondData)
|
||||
require.NoError(t, err)
|
||||
|
||||
oldContent, err := os.ReadFile(oldLogPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(firstData), string(oldContent))
|
||||
|
||||
newContent, err := os.ReadFile(logPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(secondData), string(newContent))
|
||||
|
||||
assert.Equal(t, int64(len(secondData)), rw.currentSize)
|
||||
}
|
||||
|
||||
func Test_Write_WhenOldFileExists_OldFileReplaced(t *testing.T) {
|
||||
rw, _, oldLogPath := setupRotatingWriter(t, 100)
|
||||
|
||||
require.NoError(t, os.WriteFile(oldLogPath, []byte("stale data"), 0o644))
|
||||
|
||||
_, err := rw.Write([]byte(strings.Repeat("A", 80)))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = rw.Write([]byte(strings.Repeat("B", 30)))
|
||||
require.NoError(t, err)
|
||||
|
||||
oldContent, err := os.ReadFile(oldLogPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, strings.Repeat("A", 80), string(oldContent))
|
||||
}
|
||||
|
||||
func Test_Write_MultipleSmallWrites_CurrentSizeAccumulated(t *testing.T) {
|
||||
rw, _, _ := setupRotatingWriter(t, 1024)
|
||||
|
||||
var totalWritten int64
|
||||
for i := 0; i < 10; i++ {
|
||||
data := []byte("line\n")
|
||||
n, err := rw.Write(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
totalWritten += int64(n)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalWritten, rw.currentSize)
|
||||
assert.Equal(t, int64(50), rw.currentSize)
|
||||
}
|
||||
|
||||
func Test_Write_ExactlyAtBoundary_NoRotationUntilNextByte(t *testing.T) {
|
||||
rw, logPath, oldLogPath := setupRotatingWriter(t, 100)
|
||||
|
||||
exactData := []byte(strings.Repeat("X", 100))
|
||||
_, err := rw.Write(exactData)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = os.Stat(oldLogPath)
|
||||
assert.True(t, os.IsNotExist(err), ".old file should not exist yet")
|
||||
|
||||
content, err := os.ReadFile(logPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(exactData), string(content))
|
||||
|
||||
_, err = rw.Write([]byte("Z"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = os.Stat(oldLogPath)
|
||||
assert.NoError(t, err, ".old file should exist after exceeding limit")
|
||||
|
||||
assert.Equal(t, int64(1), rw.currentSize)
|
||||
}
|
||||
|
||||
func setupRotatingWriter(t *testing.T, maxSize int64) (*rotatingWriter, string, string) {
|
||||
t.Helper()
|
||||
|
||||
dir := t.TempDir()
|
||||
logPath := filepath.Join(dir, "test.log")
|
||||
oldLogPath := filepath.Join(dir, "test.log.old")
|
||||
|
||||
f, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
rw := &rotatingWriter{
|
||||
file: f,
|
||||
currentSize: 0,
|
||||
maxSize: maxSize,
|
||||
logPath: logPath,
|
||||
oldLogPath: oldLogPath,
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
rw.file.Close()
|
||||
})
|
||||
|
||||
return rw, logPath, oldLogPath
|
||||
}
|
||||
@@ -7,6 +7,16 @@ run:
|
||||
|
||||
linters:
|
||||
default: standard
|
||||
enable:
|
||||
- funcorder
|
||||
- bodyclose
|
||||
- errorlint
|
||||
- gocritic
|
||||
- unconvert
|
||||
- misspell
|
||||
- errname
|
||||
- noctx
|
||||
- modernize
|
||||
|
||||
settings:
|
||||
errcheck:
|
||||
@@ -14,6 +24,18 @@ linters:
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
- gofmt
|
||||
- gofumpt
|
||||
- golines
|
||||
- goimports
|
||||
- gci
|
||||
|
||||
settings:
|
||||
golines:
|
||||
max-len: 120
|
||||
gofumpt:
|
||||
module-path: databasus-backend
|
||||
extra-rules: true
|
||||
gci:
|
||||
sections:
|
||||
- standard
|
||||
- default
|
||||
- localmodule
|
||||
|
||||
@@ -12,6 +12,12 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-contrib/gzip"
|
||||
"github.com/gin-gonic/gin"
|
||||
swaggerFiles "github.com/swaggo/files"
|
||||
ginSwagger "github.com/swaggo/gin-swagger"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
@@ -28,6 +34,7 @@ import (
|
||||
"databasus-backend/internal/features/restores"
|
||||
"databasus-backend/internal/features/restores/restoring"
|
||||
"databasus-backend/internal/features/storages"
|
||||
system_agent "databasus-backend/internal/features/system/agent"
|
||||
system_healthcheck "databasus-backend/internal/features/system/healthcheck"
|
||||
system_version "databasus-backend/internal/features/system/version"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
@@ -40,12 +47,6 @@ import (
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
"databasus-backend/internal/util/logger"
|
||||
_ "databasus-backend/swagger" // swagger docs
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-contrib/gzip"
|
||||
"github.com/gin-gonic/gin"
|
||||
swaggerFiles "github.com/swaggo/files"
|
||||
ginSwagger "github.com/swaggo/gin-swagger"
|
||||
)
|
||||
|
||||
// @title Databasus Backend API
|
||||
@@ -82,7 +83,6 @@ func main() {
|
||||
config.GetEnv().TempFolder,
|
||||
config.GetEnv().DataFolder,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error("Failed to ensure directories", "error", err)
|
||||
os.Exit(1)
|
||||
@@ -149,7 +149,7 @@ func handlePasswordReset(log *slog.Logger) {
|
||||
resetPassword(*email, *newPassword, log)
|
||||
}
|
||||
|
||||
func resetPassword(email string, newPassword string, log *slog.Logger) {
|
||||
func resetPassword(email, newPassword string, log *slog.Logger) {
|
||||
log.Info("Resetting password...")
|
||||
|
||||
userService := users_services.GetUserService()
|
||||
@@ -212,6 +212,7 @@ func setUpRoutes(r *gin.Engine) {
|
||||
userController.RegisterRoutes(v1)
|
||||
system_healthcheck.GetHealthcheckController().RegisterRoutes(v1)
|
||||
system_version.GetVersionController().RegisterRoutes(v1)
|
||||
system_agent.GetAgentController().RegisterRoutes(v1)
|
||||
backups_controllers.GetBackupController().RegisterPublicRoutes(v1)
|
||||
backups_controllers.GetPostgresWalBackupController().RegisterRoutes(v1)
|
||||
databases.GetDatabaseController().RegisterPublicRoutes(v1)
|
||||
@@ -352,7 +353,9 @@ func generateSwaggerDocs(log *slog.Logger) {
|
||||
return
|
||||
}
|
||||
|
||||
cmd := exec.Command("swag", "init", "-d", currentDir, "-g", "cmd/main.go", "-o", "swagger")
|
||||
cmd := exec.CommandContext(
|
||||
context.Background(), "swag", "init", "-d", currentDir, "-g", "cmd/main.go", "-o", "swagger",
|
||||
)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
@@ -366,7 +369,7 @@ func generateSwaggerDocs(log *slog.Logger) {
|
||||
func runMigrations(log *slog.Logger) {
|
||||
log.Info("Running database migrations...")
|
||||
|
||||
cmd := exec.Command("goose", "-dir", "./migrations", "up")
|
||||
cmd := exec.CommandContext(context.Background(), "goose", "-dir", "./migrations", "up")
|
||||
cmd.Env = append(
|
||||
os.Environ(),
|
||||
"GOOSE_DRIVER=postgres",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
module databasus-backend
|
||||
|
||||
go 1.24.9
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
env_utils "databasus-backend/internal/util/env"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"databasus-backend/internal/util/tools"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -11,6 +8,10 @@ import (
|
||||
|
||||
"github.com/ilyakaznacheev/cleanenv"
|
||||
"github.com/joho/godotenv"
|
||||
|
||||
env_utils "databasus-backend/internal/util/env"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
var log = logger.GetLogger()
|
||||
@@ -29,7 +30,7 @@ type EnvVariables struct {
|
||||
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`
|
||||
|
||||
// Internal database
|
||||
DatabaseDsn string `env:"DATABASE_DSN" required:"true"`
|
||||
DatabaseDsn string `env:"DATABASE_DSN" required:"true"`
|
||||
// Internal Valkey
|
||||
ValkeyHost string `env:"VALKEY_HOST" required:"true"`
|
||||
ValkeyPort string `env:"VALKEY_PORT" required:"true"`
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/gorm"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
func Test_CleanOldAuditLogs_DeletesLogsOlderThanOneYear(t *testing.T) {
|
||||
|
||||
@@ -4,10 +4,10 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
|
||||
@@ -6,15 +6,15 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_GetGlobalAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly(t *testing.T) {
|
||||
|
||||
@@ -8,14 +8,18 @@ import (
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var auditLogRepository = &AuditLogRepository{}
|
||||
var auditLogService = &AuditLogService{
|
||||
auditLogRepository,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
var (
|
||||
auditLogRepository = &AuditLogRepository{}
|
||||
auditLogService = &AuditLogService{
|
||||
auditLogRepository,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
)
|
||||
|
||||
var auditLogController = &AuditLogController{
|
||||
auditLogService,
|
||||
}
|
||||
|
||||
var auditLogBackgroundService = &AuditLogBackgroundService{
|
||||
auditLogService: auditLogService,
|
||||
logger: logger.GetLogger(),
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
type AuditLogRepository struct{}
|
||||
@@ -21,7 +22,7 @@ func (r *AuditLogRepository) GetGlobal(
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLogDTO, error) {
|
||||
var auditLogs = make([]*AuditLogDTO, 0)
|
||||
auditLogs := make([]*AuditLogDTO, 0)
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
@@ -37,7 +38,7 @@ func (r *AuditLogRepository) GetGlobal(
|
||||
LEFT JOIN users u ON al.user_id = u.id
|
||||
LEFT JOIN workspaces w ON al.workspace_id = w.id`
|
||||
|
||||
args := []interface{}{}
|
||||
args := []any{}
|
||||
|
||||
if beforeDate != nil {
|
||||
sql += " WHERE al.created_at < ?"
|
||||
@@ -57,7 +58,7 @@ func (r *AuditLogRepository) GetByUser(
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLogDTO, error) {
|
||||
var auditLogs = make([]*AuditLogDTO, 0)
|
||||
auditLogs := make([]*AuditLogDTO, 0)
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
@@ -74,7 +75,7 @@ func (r *AuditLogRepository) GetByUser(
|
||||
LEFT JOIN workspaces w ON al.workspace_id = w.id
|
||||
WHERE al.user_id = ?`
|
||||
|
||||
args := []interface{}{userID}
|
||||
args := []any{userID}
|
||||
|
||||
if beforeDate != nil {
|
||||
sql += " AND al.created_at < ?"
|
||||
@@ -94,7 +95,7 @@ func (r *AuditLogRepository) GetByWorkspace(
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLogDTO, error) {
|
||||
var auditLogs = make([]*AuditLogDTO, 0)
|
||||
auditLogs := make([]*AuditLogDTO, 0)
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
@@ -111,7 +112,7 @@ func (r *AuditLogRepository) GetByWorkspace(
|
||||
LEFT JOIN workspaces w ON al.workspace_id = w.id
|
||||
WHERE al.workspace_id = ?`
|
||||
|
||||
args := []interface{}{workspaceID}
|
||||
args := []any{workspaceID}
|
||||
|
||||
if beforeDate != nil {
|
||||
sql += " AND al.created_at < ?"
|
||||
|
||||
@@ -4,10 +4,10 @@ import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogService struct {
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
)
|
||||
|
||||
func Test_AuditLogs_WorkspaceSpecificLogs(t *testing.T) {
|
||||
|
||||
@@ -5,6 +5,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -14,9 +17,6 @@ import (
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
|
||||
@@ -59,6 +59,10 @@ func (c *BackupCleaner) Run(ctx context.Context) {
|
||||
if err := c.cleanExceededBackups(); err != nil {
|
||||
c.logger.Error("Failed to clean exceeded backups", "error", err)
|
||||
}
|
||||
|
||||
if err := c.cleanStaleUploadedBasebackups(); err != nil {
|
||||
c.logger.Error("Failed to clean stale uploaded basebackups", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -100,6 +104,67 @@ func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemo
|
||||
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanStaleUploadedBasebackups() error {
|
||||
staleBackups, err := c.backupRepository.FindStaleUploadedBasebackups(
|
||||
time.Now().UTC().Add(-10 * time.Minute),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find stale uploaded basebackups: %w", err)
|
||||
}
|
||||
|
||||
for _, backup := range staleBackups {
|
||||
staleStorage, storageErr := c.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr != nil {
|
||||
c.logger.Error(
|
||||
"Failed to get storage for stale basebackup cleanup",
|
||||
"backupId", backup.ID,
|
||||
"storageId", backup.StorageID,
|
||||
"error", storageErr,
|
||||
)
|
||||
} else {
|
||||
if err := staleStorage.DeleteFile(c.fieldEncryptor, backup.FileName); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete stale basebackup file",
|
||||
"backupId", backup.ID,
|
||||
"fileName", backup.FileName,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
|
||||
metadataFileName := backup.FileName + ".metadata"
|
||||
if err := staleStorage.DeleteFile(c.fieldEncryptor, metadataFileName); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete stale basebackup metadata file",
|
||||
"backupId", backup.ID,
|
||||
"fileName", metadataFileName,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
failMsg := "basebackup finalization timed out after 10 minutes"
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.FailMessage = &failMsg
|
||||
|
||||
if err := c.backupRepository.Save(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to mark stale uploaded basebackup as failed",
|
||||
"backupId", backup.ID,
|
||||
"error", err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Marked stale uploaded basebackup as failed and cleaned storage",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", backup.DatabaseID,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByRetentionPolicy() error {
|
||||
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
@@ -446,22 +511,24 @@ func buildGFSKeepSet(
|
||||
}
|
||||
|
||||
dailyCutoff := rawDailyCutoff
|
||||
if weeks > 0 {
|
||||
switch {
|
||||
case weeks > 0:
|
||||
dailyCutoff = laterOf(dailyCutoff, weeklyCutoff)
|
||||
} else if months > 0 {
|
||||
case months > 0:
|
||||
dailyCutoff = laterOf(dailyCutoff, monthlyCutoff)
|
||||
} else if years > 0 {
|
||||
case years > 0:
|
||||
dailyCutoff = laterOf(dailyCutoff, yearlyCutoff)
|
||||
}
|
||||
|
||||
hourlyCutoff := rawHourlyCutoff
|
||||
if days > 0 {
|
||||
switch {
|
||||
case days > 0:
|
||||
hourlyCutoff = laterOf(hourlyCutoff, dailyCutoff)
|
||||
} else if weeks > 0 {
|
||||
case weeks > 0:
|
||||
hourlyCutoff = laterOf(hourlyCutoff, weeklyCutoff)
|
||||
} else if months > 0 {
|
||||
case months > 0:
|
||||
hourlyCutoff = laterOf(hourlyCutoff, monthlyCutoff)
|
||||
} else if years > 0 {
|
||||
case years > 0:
|
||||
hourlyCutoff = laterOf(hourlyCutoff, yearlyCutoff)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -13,9 +16,6 @@ import (
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_BuildGFSKeepSet(t *testing.T) {
|
||||
|
||||
@@ -4,6 +4,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -15,9 +18,6 @@ import (
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
"databasus-backend/internal/util/period"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_CleanOldBackups_DeletesBackupsOlderThanRetentionTimePeriod(t *testing.T) {
|
||||
@@ -1004,6 +1004,191 @@ func (m *mockBackupRemoveListener) OnBeforeBackupRemove(backup *backups_core.Bac
|
||||
return nil
|
||||
}
|
||||
|
||||
func Test_CleanStaleUploadedBasebackups_MarksAsFailed(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
staleTime := time.Now().UTC().Add(-15 * time.Minute)
|
||||
walBackupType := backups_core.PgWalBackupTypeFullBackup
|
||||
staleBackup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
PgWalBackupType: &walBackupType,
|
||||
UploadCompletedAt: &staleTime,
|
||||
CreatedAt: staleTime,
|
||||
}
|
||||
|
||||
err := backupRepository.Save(staleBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(staleBackup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusFailed, updated.Status)
|
||||
assert.NotNil(t, updated.FailMessage)
|
||||
assert.Contains(t, *updated.FailMessage, "finalization timed out")
|
||||
}
|
||||
|
||||
func Test_CleanStaleUploadedBasebackups_SkipsRecentUploads(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
recentTime := time.Now().UTC().Add(-2 * time.Minute)
|
||||
walBackupType := backups_core.PgWalBackupTypeFullBackup
|
||||
recentBackup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
PgWalBackupType: &walBackupType,
|
||||
UploadCompletedAt: &recentTime,
|
||||
CreatedAt: recentTime,
|
||||
}
|
||||
|
||||
err := backupRepository.Save(recentBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(recentBackup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, updated.Status)
|
||||
}
|
||||
|
||||
func Test_CleanStaleUploadedBasebackups_SkipsActiveStreaming(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
walBackupType := backups_core.PgWalBackupTypeFullBackup
|
||||
activeBackup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
PgWalBackupType: &walBackupType,
|
||||
CreatedAt: time.Now().UTC().Add(-30 * time.Minute),
|
||||
}
|
||||
|
||||
err := backupRepository.Save(activeBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(activeBackup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, updated.Status)
|
||||
assert.Nil(t, updated.UploadCompletedAt)
|
||||
}
|
||||
|
||||
func Test_CleanStaleUploadedBasebackups_CleansStorageFiles(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
staleTime := time.Now().UTC().Add(-15 * time.Minute)
|
||||
walBackupType := backups_core.PgWalBackupTypeFullBackup
|
||||
staleBackup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
PgWalBackupType: &walBackupType,
|
||||
UploadCompletedAt: &staleTime,
|
||||
BackupSizeMb: 500,
|
||||
FileName: "stale-basebackup-test-file",
|
||||
CreatedAt: staleTime,
|
||||
}
|
||||
|
||||
err := backupRepository.Save(staleBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanStaleUploadedBasebackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
updated, err := backupRepository.FindByID(staleBackup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusFailed, updated.Status)
|
||||
assert.NotNil(t, updated.FailMessage)
|
||||
assert.Contains(t, *updated.FailMessage, "finalization timed out")
|
||||
}
|
||||
|
||||
func createTestInterval() *intervals.Interval {
|
||||
timeOfDay := "04:00"
|
||||
interval := &intervals.Interval{
|
||||
|
||||
@@ -6,15 +6,15 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type MockNotificationSender struct {
|
||||
|
||||
@@ -10,10 +10,10 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/valkey-io/valkey-go"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -415,7 +415,7 @@ func (r *BackupNodesRegistry) UnsubscribeNodeForBackupsAssignments() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID uuid.UUID, backupID uuid.UUID) error {
|
||||
func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID, backupID uuid.UUID) error {
|
||||
ctx := context.Background()
|
||||
|
||||
message := BackupCompletionMessage{
|
||||
@@ -437,7 +437,7 @@ func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID uuid.UUID, backupID
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
|
||||
handler func(nodeID uuid.UUID, backupID uuid.UUID),
|
||||
handler func(nodeID, backupID uuid.UUID),
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
func Test_HearthbeatNodeInRegistry_RegistersNodeWithTTL(t *testing.T) {
|
||||
@@ -903,7 +903,7 @@ func Test_SubscribeForBackupsCompletions_ReceivesCompletedBackups(t *testing.T)
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
receivedNodeID := make(chan uuid.UUID, 1)
|
||||
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {
|
||||
handler := func(nodeID, backupID uuid.UUID) {
|
||||
receivedNodeID <- nodeID
|
||||
receivedBackupID <- backupID
|
||||
}
|
||||
@@ -940,7 +940,7 @@ func Test_SubscribeForBackupsCompletions_ParsesJsonCorrectly(t *testing.T) {
|
||||
defer registry.UnsubscribeForBackupsCompletions()
|
||||
|
||||
receivedBackups := make(chan uuid.UUID, 2)
|
||||
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {
|
||||
handler := func(nodeID, backupID uuid.UUID) {
|
||||
receivedBackups <- backupID
|
||||
}
|
||||
|
||||
@@ -969,7 +969,7 @@ func Test_SubscribeForBackupsCompletions_HandlesInvalidJson(t *testing.T) {
|
||||
defer registry.UnsubscribeForBackupsCompletions()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {
|
||||
handler := func(nodeID, backupID uuid.UUID) {
|
||||
receivedBackupID <- backupID
|
||||
}
|
||||
|
||||
@@ -997,7 +997,7 @@ func Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages(t *testing.T)
|
||||
backupID2 := uuid.New()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 2)
|
||||
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {
|
||||
handler := func(nodeID, backupID uuid.UUID) {
|
||||
receivedBackupID <- backupID
|
||||
}
|
||||
|
||||
@@ -1032,7 +1032,7 @@ func Test_SubscribeForBackupsCompletions_WhenAlreadySubscribed_ReturnsError(t *t
|
||||
registry := createTestRegistry()
|
||||
defer registry.UnsubscribeForBackupsCompletions()
|
||||
|
||||
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {}
|
||||
handler := func(nodeID, backupID uuid.UUID) {}
|
||||
|
||||
err := registry.SubscribeForBackupsCompletions(handler)
|
||||
assert.NoError(t, err)
|
||||
@@ -1064,9 +1064,9 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
|
||||
receivedBackups2 := make(chan uuid.UUID, 3)
|
||||
receivedBackups3 := make(chan uuid.UUID, 3)
|
||||
|
||||
handler1 := func(nodeID uuid.UUID, backupID uuid.UUID) { receivedBackups1 <- backupID }
|
||||
handler2 := func(nodeID uuid.UUID, backupID uuid.UUID) { receivedBackups2 <- backupID }
|
||||
handler3 := func(nodeID uuid.UUID, backupID uuid.UUID) { receivedBackups3 <- backupID }
|
||||
handler1 := func(nodeID, backupID uuid.UUID) { receivedBackups1 <- backupID }
|
||||
handler2 := func(nodeID, backupID uuid.UUID) { receivedBackups2 <- backupID }
|
||||
handler3 := func(nodeID, backupID uuid.UUID) { receivedBackups3 <- backupID }
|
||||
|
||||
err := registry1.SubscribeForBackupsCompletions(handler1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -441,7 +441,7 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
|
||||
return &bestNode.ID, nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) onBackupCompleted(nodeID uuid.UUID, backupID uuid.UUID) {
|
||||
func (s *BackupsScheduler) onBackupCompleted(nodeID, backupID uuid.UUID) {
|
||||
// Verify this task is actually a backup (registry contains multiple task types)
|
||||
_, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -12,11 +18,6 @@ import (
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/period"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_RunPendingBackups_WhenLastBackupWasYesterday_CreatesNewBackup(t *testing.T) {
|
||||
|
||||
@@ -8,6 +8,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
@@ -19,9 +22,6 @@ import (
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func CreateTestRouter() *gin.Engine {
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"errors"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
)
|
||||
|
||||
type BackupMetadata struct {
|
||||
|
||||
@@ -7,6 +7,10 @@ type CountingWriter struct {
|
||||
BytesWritten int64
|
||||
}
|
||||
|
||||
func NewCountingWriter(writer io.Writer) *CountingWriter {
|
||||
return &CountingWriter{Writer: writer}
|
||||
}
|
||||
|
||||
func (cw *CountingWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = cw.Writer.Write(p)
|
||||
cw.BytesWritten += int64(n)
|
||||
@@ -16,7 +20,3 @@ func (cw *CountingWriter) Write(p []byte) (n int, err error) {
|
||||
func (cw *CountingWriter) GetBytesWritten() int64 {
|
||||
return cw.BytesWritten
|
||||
}
|
||||
|
||||
func NewCountingWriter(writer io.Writer) *CountingWriter {
|
||||
return &CountingWriter{Writer: writer}
|
||||
}
|
||||
|
||||
@@ -2,13 +2,7 @@ package backups_controllers
|
||||
|
||||
import (
|
||||
"context"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
backups_dto "databasus-backend/internal/features/backups/backups/dto"
|
||||
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||
"databasus-backend/internal/features/databases"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -16,6 +10,14 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
backups_dto "databasus-backend/internal/features/backups/backups/dto"
|
||||
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||
"databasus-backend/internal/features/databases"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
)
|
||||
|
||||
type BackupController struct {
|
||||
@@ -197,7 +199,7 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
|
||||
|
||||
response, err := c.backupService.GenerateDownloadToken(user, id)
|
||||
if err != nil {
|
||||
if err == backups_download.ErrDownloadAlreadyInProgress {
|
||||
if errors.Is(err, backups_download.ErrDownloadAlreadyInProgress) {
|
||||
ctx.JSON(
|
||||
http.StatusConflict,
|
||||
gin.H{
|
||||
@@ -248,7 +250,7 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
|
||||
downloadToken, rateLimiter, err := c.backupService.ValidateDownloadToken(token)
|
||||
if err != nil {
|
||||
if err == backups_download.ErrDownloadAlreadyInProgress {
|
||||
if errors.Is(err, backups_download.ErrDownloadAlreadyInProgress) {
|
||||
ctx.JSON(
|
||||
http.StatusConflict,
|
||||
gin.H{
|
||||
|
||||
@@ -3,15 +3,13 @@ package backups_controllers
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_dto "databasus-backend/internal/features/backups/backups/dto"
|
||||
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||
"databasus-backend/internal/features/databases"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_dto "databasus-backend/internal/features/backups/backups/dto"
|
||||
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||
"databasus-backend/internal/features/databases"
|
||||
)
|
||||
|
||||
// PostgreWalBackupController handles WAL backup endpoints used by the databasus-cli agent.
|
||||
@@ -25,8 +23,11 @@ func (c *PostgreWalBackupController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
walRoutes := router.Group("/backups/postgres/wal")
|
||||
|
||||
walRoutes.GET("/next-full-backup-time", c.GetNextFullBackupTime)
|
||||
walRoutes.GET("/is-wal-chain-valid-since-last-full-backup", c.IsWalChainValidSinceLastBackup)
|
||||
walRoutes.POST("/error", c.ReportError)
|
||||
walRoutes.POST("/upload", c.Upload)
|
||||
walRoutes.POST("/upload/wal", c.UploadWalSegment)
|
||||
walRoutes.POST("/upload/full-start", c.StartFullBackupUpload)
|
||||
walRoutes.POST("/upload/full-complete", c.CompleteFullBackupUpload)
|
||||
walRoutes.GET("/restore/plan", c.GetRestorePlan)
|
||||
walRoutes.GET("/restore/download", c.DownloadBackupFile)
|
||||
}
|
||||
@@ -90,91 +91,66 @@ func (c *PostgreWalBackupController) ReportError(ctx *gin.Context) {
|
||||
ctx.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
// Upload
|
||||
// @Summary Stream upload a basebackup or WAL segment
|
||||
// @Description Accepts a zstd-compressed binary stream and stores it in the database's configured storage.
|
||||
// The server generates the storage filename; agents do not control the destination path.
|
||||
// For WAL segment uploads the server validates the WAL chain and returns 409 if a gap is detected
|
||||
// or 400 if no full backup exists yet (agent should trigger a full basebackup in both cases).
|
||||
// IsWalChainValidSinceLastBackup
|
||||
// @Summary Check WAL chain validity since last full backup
|
||||
// @Description Checks whether the WAL chain is continuous since the last completed full backup.
|
||||
// Returns isValid=true if the chain is intact, or isValid=false with error details if not.
|
||||
// @Tags backups-wal
|
||||
// @Accept application/octet-stream
|
||||
// @Produce json
|
||||
// @Security AgentToken
|
||||
// @Param X-Upload-Type header string true "Upload type" Enums(basebackup, wal)
|
||||
// @Param X-Wal-Segment-Name header string false "24-hex WAL segment identifier (required for wal uploads, e.g. 0000000100000001000000AB)"
|
||||
// @Param X-Wal-Segment-Size header int false "WAL segment size in bytes reported by the PostgreSQL instance (default: 16777216)"
|
||||
// @Param fullBackupWalStartSegment query string false "First WAL segment needed to make the basebackup consistent (required for basebackup uploads)"
|
||||
// @Param fullBackupWalStopSegment query string false "Last WAL segment included in the basebackup (required for basebackup uploads)"
|
||||
// @Success 204
|
||||
// @Failure 400 {object} backups_dto.UploadGapResponse "No full backup exists (error: no_full_backup)"
|
||||
// @Success 200 {object} backups_dto.IsWalChainValidResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 409 {object} backups_dto.UploadGapResponse "WAL chain gap detected (error: gap_detected)"
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /backups/postgres/wal/upload [post]
|
||||
func (c *PostgreWalBackupController) Upload(ctx *gin.Context) {
|
||||
// @Router /backups/postgres/wal/is-wal-chain-valid-since-last-full-backup [get]
|
||||
func (c *PostgreWalBackupController) IsWalChainValidSinceLastBackup(ctx *gin.Context) {
|
||||
database, err := c.getDatabase(ctx)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
|
||||
return
|
||||
}
|
||||
|
||||
uploadType := backups_core.PgWalUploadType(ctx.GetHeader("X-Upload-Type"))
|
||||
if uploadType != backups_core.PgWalUploadTypeBasebackup &&
|
||||
uploadType != backups_core.PgWalUploadTypeWal {
|
||||
response, err := c.walService.IsWalChainValid(database)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// UploadWalSegment
|
||||
// @Summary Stream upload a WAL segment
|
||||
// @Description Accepts a zstd-compressed WAL segment binary stream and stores it in the database's configured storage.
|
||||
// WAL segments are accepted unconditionally.
|
||||
// @Tags backups-wal
|
||||
// @Accept application/octet-stream
|
||||
// @Security AgentToken
|
||||
// @Param X-Wal-Segment-Name header string true "24-hex WAL segment identifier (e.g. 0000000100000001000000AB)"
|
||||
// @Success 204
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /backups/postgres/wal/upload/wal [post]
|
||||
func (c *PostgreWalBackupController) UploadWalSegment(ctx *gin.Context) {
|
||||
database, err := c.getDatabase(ctx)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
|
||||
return
|
||||
}
|
||||
|
||||
walSegmentName := ctx.GetHeader("X-Wal-Segment-Name")
|
||||
if walSegmentName == "" {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "X-Upload-Type must be 'basebackup' or 'wal'"},
|
||||
gin.H{"error": "X-Wal-Segment-Name is required for wal uploads"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
walSegmentName := ""
|
||||
if uploadType == backups_core.PgWalUploadTypeWal {
|
||||
walSegmentName = ctx.GetHeader("X-Wal-Segment-Name")
|
||||
if walSegmentName == "" {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "X-Wal-Segment-Name is required for wal uploads"},
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if uploadType == backups_core.PgWalUploadTypeBasebackup {
|
||||
if ctx.Query("fullBackupWalStartSegment") == "" ||
|
||||
ctx.Query("fullBackupWalStopSegment") == "" {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{
|
||||
"error": "fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads",
|
||||
},
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
walSegmentSizeBytes := int64(0)
|
||||
if raw := ctx.GetHeader("X-Wal-Segment-Size"); raw != "" {
|
||||
parsed, parseErr := strconv.ParseInt(raw, 10, 64)
|
||||
if parseErr != nil || parsed <= 0 {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "X-Wal-Segment-Size must be a positive integer"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
walSegmentSizeBytes = parsed
|
||||
}
|
||||
|
||||
gapResp, uploadErr := c.walService.UploadWal(
|
||||
uploadErr := c.walService.UploadWalSegment(
|
||||
ctx.Request.Context(),
|
||||
database,
|
||||
uploadType,
|
||||
walSegmentName,
|
||||
ctx.Query("fullBackupWalStartSegment"),
|
||||
ctx.Query("fullBackupWalStopSegment"),
|
||||
walSegmentSizeBytes,
|
||||
ctx.Request.Body,
|
||||
)
|
||||
|
||||
@@ -183,17 +159,89 @@ func (c *PostgreWalBackupController) Upload(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if gapResp != nil {
|
||||
if gapResp.Error == "no_full_backup" {
|
||||
ctx.JSON(http.StatusBadRequest, gapResp)
|
||||
return
|
||||
}
|
||||
ctx.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusConflict, gapResp)
|
||||
// StartFullBackupUpload
|
||||
// @Summary Stream upload a full basebackup (Phase 1)
|
||||
// @Description Accepts a zstd-compressed basebackup binary stream and stores it in the database's configured storage.
|
||||
// Returns a backupId that must be completed via /upload/full-complete with WAL segment names.
|
||||
// @Tags backups-wal
|
||||
// @Accept application/octet-stream
|
||||
// @Produce json
|
||||
// @Security AgentToken
|
||||
// @Success 200 {object} backups_dto.UploadBasebackupResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /backups/postgres/wal/upload/full-start [post]
|
||||
func (c *PostgreWalBackupController) StartFullBackupUpload(ctx *gin.Context) {
|
||||
database, err := c.getDatabase(ctx)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusNoContent)
|
||||
backupID, uploadErr := c.walService.UploadBasebackup(
|
||||
ctx.Request.Context(),
|
||||
database,
|
||||
ctx.Request.Body,
|
||||
)
|
||||
|
||||
if uploadErr != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": uploadErr.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, backups_dto.UploadBasebackupResponse{
|
||||
BackupID: backupID,
|
||||
})
|
||||
}
|
||||
|
||||
// CompleteFullBackupUpload
|
||||
// @Summary Complete a previously uploaded basebackup (Phase 2)
|
||||
// @Description Sets WAL segment names and marks the basebackup as completed, or marks it as failed if an error is provided.
|
||||
// @Tags backups-wal
|
||||
// @Accept json
|
||||
// @Security AgentToken
|
||||
// @Param request body backups_dto.FinalizeBasebackupRequest true "Completion details"
|
||||
// @Success 200
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /backups/postgres/wal/upload/full-complete [post]
|
||||
func (c *PostgreWalBackupController) CompleteFullBackupUpload(ctx *gin.Context) {
|
||||
database, err := c.getDatabase(ctx)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
|
||||
return
|
||||
}
|
||||
|
||||
var request backups_dto.FinalizeBasebackupRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if request.Error == nil && (request.StartSegment == "" || request.StopSegment == "") {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "startSegment and stopSegment are required when no error is provided"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.walService.FinalizeBasebackup(
|
||||
database,
|
||||
request.BackupID,
|
||||
request.StartSegment,
|
||||
request.StopSegment,
|
||||
request.Error,
|
||||
); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
// GetRestorePlan
|
||||
|
||||
@@ -10,6 +10,11 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_dto "databasus-backend/internal/features/backups/backups/dto"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
@@ -23,11 +28,6 @@ import (
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_WalUpload_InProgressStatusSetBeforeStream(t *testing.T) {
|
||||
@@ -38,7 +38,7 @@ func Test_WalUpload_InProgressStatusSetBeforeStream(t *testing.T) {
|
||||
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
req := newWalUploadRequest(pr, agentToken, "wal", "000000010000000100000011", "", "")
|
||||
req := newWalSegmentUploadRequest(pr, agentToken, "000000010000000100000011")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
done := make(chan struct{})
|
||||
@@ -67,7 +67,7 @@ func Test_WalUpload_CompletedStatusAfterSuccessfulStream(t *testing.T) {
|
||||
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
|
||||
|
||||
body := bytes.NewReader([]byte("wal segment content"))
|
||||
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000011", "", "")
|
||||
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000011")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
@@ -99,7 +99,7 @@ func Test_WalUpload_FailedStatusWithErrorOnStreamError(t *testing.T) {
|
||||
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
req := newWalUploadRequest(pr, agentToken, "wal", "000000010000000100000011", "", "")
|
||||
req := newWalSegmentUploadRequest(pr, agentToken, "000000010000000100000011")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
done := make(chan struct{})
|
||||
@@ -129,59 +129,171 @@ func Test_WalUpload_FailedStatusWithErrorOnStreamError(t *testing.T) {
|
||||
assert.NotNil(t, walBackup.FailMessage)
|
||||
}
|
||||
|
||||
func Test_WalUpload_Basebackup_MissingWalSegments_Returns400(t *testing.T) {
|
||||
func Test_WalUpload_Basebackup_StreamingUpload_Returns200WithBackupId(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
body := bytes.NewReader([]byte("basebackup content"))
|
||||
req := newWalUploadRequest(body, agentToken, backups_core.PgWalUploadTypeBasebackup, "", "", "")
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/full-start", body)
|
||||
req.Header.Set("Authorization", agentToken)
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response backups_dto.UploadBasebackupResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
assert.NotEqual(t, uuid.Nil, response.BackupID)
|
||||
|
||||
backup, err := backups_core.GetBackupRepository().FindByID(response.BackupID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, backup.Status)
|
||||
assert.NotNil(t, backup.UploadCompletedAt)
|
||||
}
|
||||
|
||||
func Test_FinalizeBasebackup_ValidSegments_MarksCompleted(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
backupID := uploadBasebackupPhase1(t, router, agentToken)
|
||||
|
||||
completeFullBackupUpload(t, router, agentToken, backupID,
|
||||
"000000010000000100000001", "000000010000000100000010", nil)
|
||||
|
||||
backup, err := backups_core.GetBackupRepository().FindByID(backupID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
require.NotNil(t, backup.PgFullBackupWalStartSegmentName)
|
||||
assert.Equal(t, "000000010000000100000001", *backup.PgFullBackupWalStartSegmentName)
|
||||
require.NotNil(t, backup.PgFullBackupWalStopSegmentName)
|
||||
assert.Equal(t, "000000010000000100000010", *backup.PgFullBackupWalStopSegmentName)
|
||||
}
|
||||
|
||||
func Test_FinalizeBasebackup_WithError_MarksFailed(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
backupID := uploadBasebackupPhase1(t, router, agentToken)
|
||||
|
||||
errMsg := "pg_basebackup stderr parse failed"
|
||||
completeFullBackupUpload(t, router, agentToken, backupID, "", "", &errMsg)
|
||||
|
||||
backup, err := backups_core.GetBackupRepository().FindByID(backupID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusFailed, backup.Status)
|
||||
require.NotNil(t, backup.FailMessage)
|
||||
assert.Equal(t, errMsg, *backup.FailMessage)
|
||||
}
|
||||
|
||||
func Test_FinalizeBasebackup_InvalidBackupId_Returns400(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
nonExistentID := uuid.New()
|
||||
body, _ := json.Marshal(backups_dto.FinalizeBasebackupRequest{
|
||||
BackupID: nonExistentID,
|
||||
StartSegment: "000000010000000100000001",
|
||||
StopSegment: "000000010000000100000010",
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/backups/postgres/wal/upload/full-complete",
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
req.Header.Set("Authorization", agentToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
func Test_WalUpload_WalSegment_NoFullBackup_Returns400(t *testing.T) {
|
||||
func Test_FinalizeBasebackup_AlreadyCompleted_Returns400(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
backupID := uploadBasebackupPhase1(t, router, agentToken)
|
||||
|
||||
completeFullBackupUpload(t, router, agentToken, backupID,
|
||||
"000000010000000100000001", "000000010000000100000010", nil)
|
||||
|
||||
// Second finalize should fail.
|
||||
body, _ := json.Marshal(backups_dto.FinalizeBasebackupRequest{
|
||||
BackupID: backupID,
|
||||
StartSegment: "000000010000000100000001",
|
||||
StopSegment: "000000010000000100000010",
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/backups/postgres/wal/upload/full-complete",
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
req.Header.Set("Authorization", agentToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
func Test_FinalizeBasebackup_InvalidToken_Returns401(t *testing.T) {
|
||||
router, db, storage, _, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
body, _ := json.Marshal(backups_dto.FinalizeBasebackupRequest{
|
||||
BackupID: uuid.New(),
|
||||
StartSegment: "000000010000000100000001",
|
||||
StopSegment: "000000010000000100000010",
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/backups/postgres/wal/upload/full-complete",
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
req.Header.Set("Authorization", "invalid-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
|
||||
func Test_WalUpload_WalSegment_WithoutFullBackup_Returns204(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
// No full backup inserted — chain anchor is missing.
|
||||
body := bytes.NewReader([]byte("wal content"))
|
||||
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000001", "", "")
|
||||
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000001")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
|
||||
var resp backups_dto.UploadGapResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||
assert.Equal(t, "no_full_backup", resp.Error)
|
||||
assert.Equal(t, http.StatusNoContent, w.Code)
|
||||
}
|
||||
|
||||
func Test_WalUpload_WalSegment_GapDetected_Returns409WithExpectedAndReceived(t *testing.T) {
|
||||
func Test_WalUpload_WalSegment_WithGap_Returns204(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
// Full backup stops at ...0010; upload one WAL segment at ...0011.
|
||||
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
|
||||
|
||||
// Send ...0013 — should be rejected because ...0012 is missing.
|
||||
// Skip ...0012, upload ...0013 — should succeed (no chain validation on upload).
|
||||
body := bytes.NewReader([]byte("wal content"))
|
||||
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000013", "", "")
|
||||
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000013")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusConflict, w.Code)
|
||||
|
||||
var resp backups_dto.UploadGapResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||
assert.Equal(t, "gap_detected", resp.Error)
|
||||
assert.Equal(t, "000000010000000100000012", resp.ExpectedSegmentName)
|
||||
assert.Equal(t, "000000010000000100000013", resp.ReceivedSegmentName)
|
||||
assert.Equal(t, http.StatusNoContent, w.Code)
|
||||
}
|
||||
|
||||
func Test_WalUpload_WalSegment_DuplicateSegment_Returns200Idempotent(t *testing.T) {
|
||||
@@ -192,14 +304,14 @@ func Test_WalUpload_WalSegment_DuplicateSegment_Returns200Idempotent(t *testing.
|
||||
|
||||
// Upload ...0011 once.
|
||||
body1 := bytes.NewReader([]byte("wal content"))
|
||||
req1 := newWalUploadRequest(body1, agentToken, "wal", "000000010000000100000011", "", "")
|
||||
req1 := newWalSegmentUploadRequest(body1, agentToken, "000000010000000100000011")
|
||||
w1 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w1, req1)
|
||||
require.Equal(t, http.StatusNoContent, w1.Code)
|
||||
|
||||
// Upload the same segment again — must return 204 (idempotent).
|
||||
body2 := bytes.NewReader([]byte("wal content"))
|
||||
req2 := newWalUploadRequest(body2, agentToken, "wal", "000000010000000100000011", "", "")
|
||||
req2 := newWalSegmentUploadRequest(body2, agentToken, "000000010000000100000011")
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
@@ -228,7 +340,7 @@ func Test_WalUpload_WalSegment_ValidNextSegment_Returns200AndCreatesRecord(t *te
|
||||
|
||||
// First WAL segment after the full backup stop segment.
|
||||
body := bytes.NewReader([]byte("wal segment data"))
|
||||
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000011", "", "")
|
||||
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000011")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
@@ -255,6 +367,108 @@ func Test_WalUpload_WalSegment_ValidNextSegment_Returns200AndCreatesRecord(t *te
|
||||
assert.Equal(t, "000000010000000100000011", *walBackup.PgWalSegmentName)
|
||||
}
|
||||
|
||||
func Test_IsWalChainValid_NoFullBackup_ReturnsFalse(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
var response backups_dto.IsWalChainValidResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t, router,
|
||||
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
|
||||
agentToken,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.False(t, response.IsValid)
|
||||
assert.Equal(t, "no_full_backup", response.Error)
|
||||
}
|
||||
|
||||
func Test_IsWalChainValid_FullBackupOnly_ReturnsTrue(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
|
||||
|
||||
var response backups_dto.IsWalChainValidResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t, router,
|
||||
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
|
||||
agentToken,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.True(t, response.IsValid)
|
||||
assert.Empty(t, response.Error)
|
||||
}
|
||||
|
||||
func Test_IsWalChainValid_ContinuousChain_ReturnsTrue(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000012")
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000013")
|
||||
|
||||
var response backups_dto.IsWalChainValidResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t, router,
|
||||
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
|
||||
agentToken,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.True(t, response.IsValid)
|
||||
}
|
||||
|
||||
func Test_IsWalChainValid_BrokenChain_ReturnsFalse(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000012")
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000013")
|
||||
|
||||
// Delete the middle segment to create a gap.
|
||||
middleSeg, err := backups_core.GetBackupRepository().FindWalSegmentByName(
|
||||
db.ID, "000000010000000100000012",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, middleSeg)
|
||||
require.NoError(t, backups_core.GetBackupRepository().DeleteByID(middleSeg.ID))
|
||||
|
||||
var response backups_dto.IsWalChainValidResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t, router,
|
||||
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
|
||||
agentToken,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.False(t, response.IsValid)
|
||||
assert.Equal(t, "wal_chain_broken", response.Error)
|
||||
assert.Equal(t, "000000010000000100000011", response.LastContiguousSegment)
|
||||
}
|
||||
|
||||
func Test_IsWalChainValid_InvalidToken_Returns401(t *testing.T) {
|
||||
router, db, storage, _, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
|
||||
resp := test_utils.MakeGetRequest(
|
||||
t, router,
|
||||
"/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup",
|
||||
"invalid-token",
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "invalid agent token")
|
||||
}
|
||||
|
||||
func Test_ReportError_ValidTokenAndError_CreatesFailedBackupRecord(t *testing.T) {
|
||||
router, db, storage, agentToken, _ := createWalTestSetup(t)
|
||||
defer removeWalTestSetup(db, storage)
|
||||
@@ -457,29 +671,14 @@ func Test_GetNextFullBackupTime_WalSegmentAfterFullBackup_DoesNotImpactTime(t *t
|
||||
|
||||
setHourlyInterval(t, router, db.ID, ownerToken)
|
||||
|
||||
// Upload basebackup via API.
|
||||
bbBody := bytes.NewReader([]byte("basebackup content"))
|
||||
bbReq := newWalUploadRequest(
|
||||
bbBody, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
|
||||
"000000010000000100000001", "000000010000000100000010",
|
||||
)
|
||||
bbW := httptest.NewRecorder()
|
||||
router.ServeHTTP(bbW, bbReq)
|
||||
require.Equal(t, http.StatusNoContent, bbW.Code)
|
||||
uploadBasebackup(t, router, agentToken,
|
||||
"000000010000000100000001", "000000010000000100000010")
|
||||
|
||||
// Shift the full backup's CreatedAt to 2 hours ago.
|
||||
twoHoursAgo := time.Now().UTC().Add(-2 * time.Hour)
|
||||
updateLastFullBackupTime(t, db.ID, twoHoursAgo)
|
||||
|
||||
// Upload WAL segment via API.
|
||||
walBody := bytes.NewReader([]byte("wal segment content"))
|
||||
walReq := newWalUploadRequest(
|
||||
walBody, agentToken, backups_core.PgWalUploadTypeWal,
|
||||
"000000010000000100000011", "", "",
|
||||
)
|
||||
walW := httptest.NewRecorder()
|
||||
router.ServeHTTP(walW, walReq)
|
||||
require.Equal(t, http.StatusNoContent, walW.Code)
|
||||
uploadWalSegment(t, router, agentToken, "000000010000000100000011")
|
||||
|
||||
var response backups_dto.GetNextFullBackupTimeResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
@@ -508,15 +707,8 @@ func Test_GetNextFullBackupTime_FailedBasebackup_DoesNotImpactTime(t *testing.T)
|
||||
|
||||
setHourlyInterval(t, router, db.ID, ownerToken)
|
||||
|
||||
// Upload a successful basebackup via API.
|
||||
bbBody := bytes.NewReader([]byte("basebackup content"))
|
||||
bbReq := newWalUploadRequest(
|
||||
bbBody, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
|
||||
"000000010000000100000001", "000000010000000100000010",
|
||||
)
|
||||
bbW := httptest.NewRecorder()
|
||||
router.ServeHTTP(bbW, bbReq)
|
||||
require.Equal(t, http.StatusNoContent, bbW.Code)
|
||||
uploadBasebackup(t, router, agentToken,
|
||||
"000000010000000100000001", "000000010000000100000010")
|
||||
|
||||
// Shift the full backup's CreatedAt to 2 hours ago.
|
||||
twoHoursAgo := time.Now().UTC().Add(-2 * time.Hour)
|
||||
@@ -563,15 +755,8 @@ func Test_GetNextFullBackupTime_NewCompletedFullBackup_ImpactsTime(t *testing.T)
|
||||
|
||||
setHourlyInterval(t, router, db.ID, ownerToken)
|
||||
|
||||
// Upload first basebackup via API.
|
||||
bb1 := bytes.NewReader([]byte("first basebackup"))
|
||||
bb1Req := newWalUploadRequest(
|
||||
bb1, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
|
||||
"000000010000000100000001", "000000010000000100000010",
|
||||
)
|
||||
bb1W := httptest.NewRecorder()
|
||||
router.ServeHTTP(bb1W, bb1Req)
|
||||
require.Equal(t, http.StatusNoContent, bb1W.Code)
|
||||
uploadBasebackup(t, router, agentToken,
|
||||
"000000010000000100000001", "000000010000000100000010")
|
||||
|
||||
// Shift the first backup's CreatedAt to 3 hours ago.
|
||||
threeHoursAgo := time.Now().UTC().Add(-3 * time.Hour)
|
||||
@@ -595,15 +780,8 @@ func Test_GetNextFullBackupTime_NewCompletedFullBackup_ImpactsTime(t *testing.T)
|
||||
"first next time should be in the past (old backup)",
|
||||
)
|
||||
|
||||
// Upload second basebackup via API (created now).
|
||||
bb2 := bytes.NewReader([]byte("second basebackup"))
|
||||
bb2Req := newWalUploadRequest(
|
||||
bb2, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
|
||||
"000000010000000100000011", "000000010000000100000020",
|
||||
)
|
||||
bb2W := httptest.NewRecorder()
|
||||
router.ServeHTTP(bb2W, bb2Req)
|
||||
require.Equal(t, http.StatusNoContent, bb2W.Code)
|
||||
uploadBasebackup(t, router, agentToken,
|
||||
"000000010000000100000011", "000000010000000100000020")
|
||||
|
||||
var secondResponse backups_dto.GetNextFullBackupTimeResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
@@ -841,15 +1019,18 @@ func Test_DownloadRestoreFile_UploadThenDownload_ContentMatches(t *testing.T) {
|
||||
|
||||
uploadContent := "test-basebackup-content-for-download"
|
||||
body := bytes.NewReader([]byte(uploadContent))
|
||||
req := newWalUploadRequest(
|
||||
body, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
|
||||
"000000010000000100000001", "000000010000000100000010",
|
||||
)
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/full-start", body)
|
||||
req.Header.Set("Authorization", agentToken)
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusNoContent, w.Code)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
WaitForBackupCompletion(t, db.ID, 0, 5*time.Second)
|
||||
var uploadResp backups_dto.UploadBasebackupResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &uploadResp))
|
||||
|
||||
completeFullBackupUpload(t, router, agentToken, uploadResp.BackupID,
|
||||
"000000010000000100000001", "000000010000000100000010", nil)
|
||||
|
||||
var planResp backups_dto.GetRestorePlanResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
@@ -883,7 +1064,7 @@ func Test_DownloadRestoreFile_WalSegment_UploadThenDownload_ContentMatches(t *te
|
||||
|
||||
walContent := "test-wal-segment-content-for-download"
|
||||
body := bytes.NewReader([]byte(walContent))
|
||||
req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000011", "", "")
|
||||
req := newWalSegmentUploadRequest(body, agentToken, "000000010000000100000011")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusNoContent, w.Code)
|
||||
@@ -1088,35 +1269,81 @@ func removeWalTestSetup(db *databases.Database, storage *storages.Storage) {
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
}
|
||||
|
||||
func newWalUploadRequest(
|
||||
func newWalSegmentUploadRequest(
|
||||
body io.Reader,
|
||||
agentToken string,
|
||||
uploadType backups_core.PgWalUploadType,
|
||||
walSegmentName string,
|
||||
walStart string,
|
||||
walStop string,
|
||||
segmentName string,
|
||||
) *http.Request {
|
||||
url := "/api/v1/backups/postgres/wal/upload"
|
||||
if walStart != "" || walStop != "" {
|
||||
url += "?fullBackupWalStartSegment=" + walStart + "&fullBackupWalStopSegment=" + walStop
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, url, body)
|
||||
req, err := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/wal", body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", agentToken)
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
req.Header.Set("X-Upload-Type", string(uploadType))
|
||||
|
||||
if walSegmentName != "" {
|
||||
req.Header.Set("X-Wal-Segment-Name", walSegmentName)
|
||||
}
|
||||
req.Header.Set("X-Wal-Segment-Name", segmentName)
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func uploadBasebackupPhase1(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
agentToken string,
|
||||
) uuid.UUID {
|
||||
t.Helper()
|
||||
|
||||
body := bytes.NewReader([]byte("test-basebackup-content"))
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/full-start", body)
|
||||
req.Header.Set("Authorization", agentToken)
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response backups_dto.UploadBasebackupResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
require.NotEqual(t, uuid.Nil, response.BackupID)
|
||||
|
||||
return response.BackupID
|
||||
}
|
||||
|
||||
func completeFullBackupUpload(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
agentToken string,
|
||||
backupID uuid.UUID,
|
||||
walStart string,
|
||||
walStop string,
|
||||
errMsg *string,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
request := backups_dto.FinalizeBasebackupRequest{
|
||||
BackupID: backupID,
|
||||
StartSegment: walStart,
|
||||
StopSegment: walStop,
|
||||
Error: errMsg,
|
||||
}
|
||||
|
||||
reqBody, _ := json.Marshal(request)
|
||||
req, _ := http.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/backups/postgres/wal/upload/full-complete",
|
||||
bytes.NewReader(reqBody),
|
||||
)
|
||||
req.Header.Set("Authorization", agentToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func uploadBasebackup(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
@@ -1126,15 +1353,8 @@ func uploadBasebackup(
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
body := bytes.NewReader([]byte("test-basebackup-content"))
|
||||
req := newWalUploadRequest(
|
||||
body, agentToken, backups_core.PgWalUploadTypeBasebackup, "",
|
||||
walStart, walStop,
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusNoContent, w.Code)
|
||||
backupID := uploadBasebackupPhase1(t, router, agentToken)
|
||||
completeFullBackupUpload(t, router, agentToken, backupID, walStart, walStop, nil)
|
||||
}
|
||||
|
||||
func uploadWalSegment(
|
||||
@@ -1146,9 +1366,12 @@ func uploadWalSegment(
|
||||
t.Helper()
|
||||
|
||||
body := bytes.NewReader([]byte("test-wal-segment-content"))
|
||||
req := newWalUploadRequest(
|
||||
body, agentToken, backups_core.PgWalUploadTypeWal, segmentName, "", "",
|
||||
)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/backups/postgres/wal/upload/wal", body)
|
||||
req.Header.Set("Authorization", agentToken)
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
req.Header.Set("X-Wal-Segment-Name", segmentName)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
|
||||
@@ -4,14 +4,14 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func CreateTestRouter() *gin.Engine {
|
||||
|
||||
@@ -4,10 +4,10 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type PgWalBackupType string
|
||||
@@ -43,7 +43,8 @@ type Backup struct {
|
||||
PgVersion *string `json:"pgVersion" gorm:"column:pg_version;type:text"`
|
||||
PgWalSegmentName *string `json:"pgWalSegmentName" gorm:"column:pg_wal_segment_name;type:text"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
||||
UploadCompletedAt *time.Time `json:"uploadCompletedAt" gorm:"column:upload_completed_at"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
||||
}
|
||||
|
||||
func (b *Backup) GenerateFilename(dbName string) {
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package backups_core
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
"errors"
|
||||
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
type BackupRepository struct{}
|
||||
@@ -88,7 +88,7 @@ func (r *BackupRepository) FindLastByDatabaseID(databaseID uuid.UUID) (*Backup,
|
||||
Where("database_id = ?", databaseID).
|
||||
Order("created_at DESC").
|
||||
First(&backup).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -349,6 +349,24 @@ func (r *BackupRepository) FindWalSegmentByName(
|
||||
return &backup, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) FindStaleUploadedBasebackups(olderThan time.Time) ([]*Backup, error) {
|
||||
var backups []*Backup
|
||||
|
||||
err := storage.
|
||||
GetDb().
|
||||
Where(
|
||||
"status = ? AND upload_completed_at IS NOT NULL AND upload_completed_at < ?",
|
||||
BackupStatusInProgress,
|
||||
olderThan,
|
||||
).
|
||||
Find(&backups).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return backups, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) FindLastWalSegmentAfter(
|
||||
databaseID uuid.UUID,
|
||||
afterSegmentName string,
|
||||
|
||||
@@ -13,9 +13,11 @@ var downloadTokenRepository = &DownloadTokenRepository{}
|
||||
|
||||
var downloadTracker = NewDownloadTracker(cache_utils.GetValkeyClient())
|
||||
|
||||
var bandwidthManager *BandwidthManager
|
||||
var downloadTokenService *DownloadTokenService
|
||||
var downloadTokenBackgroundService *DownloadTokenBackgroundService
|
||||
var (
|
||||
bandwidthManager *BandwidthManager
|
||||
downloadTokenService *DownloadTokenService
|
||||
downloadTokenBackgroundService *DownloadTokenBackgroundService
|
||||
)
|
||||
|
||||
func init() {
|
||||
env := config.GetEnv()
|
||||
|
||||
@@ -66,9 +66,7 @@ func (rl *RateLimiter) Wait(bytes int64) {
|
||||
tokensNeeded := float64(bytes) - rl.availableTokens
|
||||
waitTime := time.Duration(tokensNeeded/float64(rl.bytesPerSecond)*1000) * time.Millisecond
|
||||
|
||||
if waitTime < time.Millisecond {
|
||||
waitTime = time.Millisecond
|
||||
}
|
||||
waitTime = max(waitTime, time.Millisecond)
|
||||
|
||||
rl.mu.Unlock()
|
||||
time.Sleep(waitTime)
|
||||
|
||||
@@ -2,12 +2,14 @@ package backups_download
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"databasus-backend/internal/storage"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
type DownloadTokenRepository struct{}
|
||||
@@ -28,9 +30,8 @@ func (r *DownloadTokenRepository) FindByToken(token string) (*DownloadToken, err
|
||||
err := storage.GetDb().
|
||||
Where("token = ?", token).
|
||||
First(&downloadToken).Error
|
||||
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/valkey-io/valkey-go"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -16,9 +17,7 @@ const (
|
||||
downloadHeartbeatDelay = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDownloadAlreadyInProgress = errors.New("download already in progress for this user")
|
||||
)
|
||||
var ErrDownloadAlreadyInProgress = errors.New("download already in progress for this user")
|
||||
|
||||
type DownloadTracker struct {
|
||||
cache *cache_utils.CacheUtil[string]
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package backups_dto
|
||||
|
||||
import (
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
)
|
||||
|
||||
type GetBackupsRequest struct {
|
||||
@@ -43,10 +44,10 @@ type ReportErrorRequest struct {
|
||||
Error string `json:"error" binding:"required"`
|
||||
}
|
||||
|
||||
type UploadGapResponse struct {
|
||||
Error string `json:"error"`
|
||||
ExpectedSegmentName string `json:"expectedSegmentName"`
|
||||
ReceivedSegmentName string `json:"receivedSegmentName"`
|
||||
type IsWalChainValidResponse struct {
|
||||
IsValid bool `json:"isValid"`
|
||||
Error string `json:"error,omitempty"`
|
||||
LastContiguousSegment string `json:"lastContiguousSegment,omitempty"`
|
||||
}
|
||||
|
||||
type RestorePlanFullBackup struct {
|
||||
@@ -76,3 +77,14 @@ type GetRestorePlanResponse struct {
|
||||
TotalSizeBytes int64 `json:"totalSizeBytes"`
|
||||
LatestAvailableSegment string `json:"latestAvailableSegment"`
|
||||
}
|
||||
|
||||
type UploadBasebackupResponse struct {
|
||||
BackupID uuid.UUID `json:"backupId"`
|
||||
}
|
||||
|
||||
type FinalizeBasebackupRequest struct {
|
||||
BackupID uuid.UUID `json:"backupId" binding:"required"`
|
||||
StartSegment string `json:"startSegment"`
|
||||
StopSegment string `json:"stopSegment"`
|
||||
Error *string `json:"error"`
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
@@ -69,7 +70,7 @@ func NewDecryptionReader(
|
||||
func (r *DecryptionReader) Read(p []byte) (n int, err error) {
|
||||
for len(r.buffer) < len(p) && !r.eof {
|
||||
if err := r.readAndDecryptChunk(); err != nil {
|
||||
if err == io.EOF {
|
||||
if errors.Is(err, io.EOF) {
|
||||
r.eof = true
|
||||
break
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package backups_services
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
@@ -15,8 +18,6 @@ import (
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var taskCancelManager = task_cancellation.GetTaskCancelManager()
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_dto "databasus-backend/internal/features/backups/backups/dto"
|
||||
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
|
||||
@@ -16,8 +18,6 @@ import (
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
util_wal "databasus-backend/internal/util/wal"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// PostgreWalBackupService handles WAL segment and basebackup uploads from the databasus-cli agent.
|
||||
@@ -30,75 +30,46 @@ type PostgreWalBackupService struct {
|
||||
backupService *BackupService
|
||||
}
|
||||
|
||||
// UploadWal accepts a streaming WAL segment or basebackup upload from the agent.
|
||||
// For WAL segments it validates the WAL chain before accepting. Returns an UploadGapResponse
|
||||
// (409) when the chain is broken so the agent knows to trigger a full basebackup.
|
||||
func (s *PostgreWalBackupService) UploadWal(
|
||||
// UploadWalSegment accepts a streaming WAL segment upload from the agent.
|
||||
// WAL segments are accepted unconditionally.
|
||||
func (s *PostgreWalBackupService) UploadWalSegment(
|
||||
ctx context.Context,
|
||||
database *databases.Database,
|
||||
uploadType backups_core.PgWalUploadType,
|
||||
walSegmentName string,
|
||||
fullBackupWalStartSegment string,
|
||||
fullBackupWalStopSegment string,
|
||||
walSegmentSizeBytes int64,
|
||||
body io.Reader,
|
||||
) (*backups_dto.UploadGapResponse, error) {
|
||||
) error {
|
||||
if err := s.validateWalBackupType(database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if uploadType == backups_core.PgWalUploadTypeBasebackup {
|
||||
if fullBackupWalStartSegment == "" || fullBackupWalStopSegment == "" {
|
||||
return nil, fmt.Errorf(
|
||||
"fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads",
|
||||
)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get backup config: %w", err)
|
||||
return fmt.Errorf("failed to get backup config: %w", err)
|
||||
}
|
||||
|
||||
if backupConfig.Storage == nil {
|
||||
return nil, fmt.Errorf("no storage configured for database %s", database.ID)
|
||||
return fmt.Errorf("no storage configured for database %s", database.ID)
|
||||
}
|
||||
|
||||
if uploadType == backups_core.PgWalUploadTypeWal {
|
||||
// Idempotency: check before chain validation so a successful re-upload is
|
||||
// not misidentified as a gap.
|
||||
existing, err := s.backupRepository.FindWalSegmentByName(database.ID, walSegmentName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check for duplicate WAL segment: %w", err)
|
||||
}
|
||||
|
||||
if existing != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
gapResp, err := s.validateWalChain(database.ID, walSegmentName, walSegmentSizeBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gapResp != nil {
|
||||
return gapResp, nil
|
||||
}
|
||||
existing, err := s.backupRepository.FindWalSegmentByName(database.ID, walSegmentName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check for duplicate WAL segment: %w", err)
|
||||
}
|
||||
|
||||
backup := s.createBackupRecord(
|
||||
if existing != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
backup := s.createWalSegmentRecord(
|
||||
database.ID,
|
||||
backupConfig.Storage.ID,
|
||||
uploadType,
|
||||
database.Name,
|
||||
walSegmentName,
|
||||
fullBackupWalStartSegment,
|
||||
fullBackupWalStopSegment,
|
||||
backupConfig.Encryption,
|
||||
)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return nil, fmt.Errorf("failed to create backup record: %w", err)
|
||||
return fmt.Errorf("failed to create backup record: %w", err)
|
||||
}
|
||||
|
||||
sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, body)
|
||||
@@ -106,12 +77,106 @@ func (s *PostgreWalBackupService) UploadWal(
|
||||
errMsg := streamErr.Error()
|
||||
s.markFailed(backup, errMsg)
|
||||
|
||||
return nil, fmt.Errorf("upload failed: %w", streamErr)
|
||||
return fmt.Errorf("upload failed: %w", streamErr)
|
||||
}
|
||||
|
||||
s.markCompleted(backup, sizeBytes)
|
||||
|
||||
return nil, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// UploadBasebackup accepts a streaming basebackup upload from the agent (Phase 1).
|
||||
// The backup stays IN_PROGRESS with UploadCompletedAt set after streaming finishes.
|
||||
// The agent must call FinalizeBasebackup (Phase 2) with WAL segment names to complete.
|
||||
func (s *PostgreWalBackupService) UploadBasebackup(
|
||||
ctx context.Context,
|
||||
database *databases.Database,
|
||||
body io.Reader,
|
||||
) (uuid.UUID, error) {
|
||||
if err := s.validateWalBackupType(database); err != nil {
|
||||
return uuid.Nil, err
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
return uuid.Nil, fmt.Errorf("failed to get backup config: %w", err)
|
||||
}
|
||||
|
||||
if backupConfig.Storage == nil {
|
||||
return uuid.Nil, fmt.Errorf("no storage configured for database %s", database.ID)
|
||||
}
|
||||
|
||||
backup := s.createBasebackupRecord(
|
||||
database.ID,
|
||||
backupConfig.Storage.ID,
|
||||
database.Name,
|
||||
backupConfig.Encryption,
|
||||
)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return uuid.Nil, fmt.Errorf("failed to create backup record: %w", err)
|
||||
}
|
||||
|
||||
sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, body)
|
||||
if streamErr != nil {
|
||||
errMsg := streamErr.Error()
|
||||
s.markFailed(backup, errMsg)
|
||||
|
||||
return uuid.Nil, fmt.Errorf("upload failed: %w", streamErr)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
backup.UploadCompletedAt = &now
|
||||
backup.BackupSizeMb = float64(sizeBytes) / (1024 * 1024)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return uuid.Nil, fmt.Errorf("failed to update backup after upload: %w", err)
|
||||
}
|
||||
|
||||
return backup.ID, nil
|
||||
}
|
||||
|
||||
// FinalizeBasebackup completes a previously uploaded basebackup (Phase 2).
|
||||
// Sets WAL segment names and marks the backup as COMPLETED, or marks it FAILED if errorMsg is provided.
|
||||
func (s *PostgreWalBackupService) FinalizeBasebackup(
|
||||
database *databases.Database,
|
||||
backupID uuid.UUID,
|
||||
startSegment string,
|
||||
stopSegment string,
|
||||
errorMsg *string,
|
||||
) error {
|
||||
if err := s.validateWalBackupType(database); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backup not found: %w", err)
|
||||
}
|
||||
|
||||
if backup.DatabaseID != database.ID {
|
||||
return fmt.Errorf("backup does not belong to this database")
|
||||
}
|
||||
|
||||
if backup.Status != backups_core.BackupStatusInProgress || backup.UploadCompletedAt == nil {
|
||||
return fmt.Errorf("backup is not awaiting finalization")
|
||||
}
|
||||
|
||||
if errorMsg != nil {
|
||||
s.markFailed(backup, *errorMsg)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
backup.PgFullBackupWalStartSegmentName = &startSegment
|
||||
backup.PgFullBackupWalStopSegmentName = &stopSegment
|
||||
backup.Status = backups_core.BackupStatusCompleted
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return fmt.Errorf("failed to finalize backup: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PostgreWalBackupService) GetRestorePlan(
|
||||
@@ -225,97 +290,171 @@ func (s *PostgreWalBackupService) DownloadBackupFile(
|
||||
return s.backupService.GetBackupReader(backupID)
|
||||
}
|
||||
|
||||
func (s *PostgreWalBackupService) validateWalChain(
|
||||
databaseID uuid.UUID,
|
||||
incomingSegment string,
|
||||
walSegmentSizeBytes int64,
|
||||
) (*backups_dto.UploadGapResponse, error) {
|
||||
fullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(databaseID)
|
||||
func (s *PostgreWalBackupService) GetNextFullBackupTime(
|
||||
database *databases.Database,
|
||||
) (*backups_dto.GetNextFullBackupTimeResponse, error) {
|
||||
if err := s.validateWalBackupType(database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get backup config: %w", err)
|
||||
}
|
||||
|
||||
if backupConfig.BackupInterval == nil {
|
||||
return nil, fmt.Errorf("no backup interval configured for database %s", database.ID)
|
||||
}
|
||||
|
||||
lastFullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(
|
||||
database.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query last full backup: %w", err)
|
||||
}
|
||||
|
||||
var lastBackupTime *time.Time
|
||||
if lastFullBackup != nil {
|
||||
lastBackupTime = &lastFullBackup.CreatedAt
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
nextTime := backupConfig.BackupInterval.NextTriggerTime(now, lastBackupTime)
|
||||
|
||||
return &backups_dto.GetNextFullBackupTimeResponse{
|
||||
NextFullBackupTime: nextTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ReportError creates a FAILED backup record with the agent's error message.
|
||||
func (s *PostgreWalBackupService) ReportError(
|
||||
database *databases.Database,
|
||||
errorMsg string,
|
||||
) error {
|
||||
if err := s.validateWalBackupType(database); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get backup config: %w", err)
|
||||
}
|
||||
|
||||
if backupConfig.Storage == nil {
|
||||
return fmt.Errorf("no storage configured for database %s", database.ID)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: backupConfig.Storage.ID,
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
FailMessage: &errorMsg,
|
||||
Encryption: backupConfig.Encryption,
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
backup.GenerateFilename(database.Name)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return fmt.Errorf("failed to save error backup record: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsWalChainValid checks whether the WAL chain is continuous since the last completed full backup.
|
||||
func (s *PostgreWalBackupService) IsWalChainValid(
|
||||
database *databases.Database,
|
||||
) (*backups_dto.IsWalChainValidResponse, error) {
|
||||
if err := s.validateWalBackupType(database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(database.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query full backup: %w", err)
|
||||
}
|
||||
|
||||
// No full backup exists yet: cannot accept WAL segments without a chain anchor.
|
||||
if fullBackup == nil || fullBackup.PgFullBackupWalStopSegmentName == nil {
|
||||
return &backups_dto.UploadGapResponse{
|
||||
Error: "no_full_backup",
|
||||
ExpectedSegmentName: "",
|
||||
ReceivedSegmentName: incomingSegment,
|
||||
return &backups_dto.IsWalChainValidResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
}, nil
|
||||
}
|
||||
|
||||
stopSegment := *fullBackup.PgFullBackupWalStopSegmentName
|
||||
startSegment := ""
|
||||
if fullBackup.PgFullBackupWalStartSegmentName != nil {
|
||||
startSegment = *fullBackup.PgFullBackupWalStartSegmentName
|
||||
}
|
||||
|
||||
lastWal, err := s.backupRepository.FindLastWalSegmentAfter(databaseID, stopSegment)
|
||||
walSegments, err := s.backupRepository.FindCompletedWalSegmentsAfter(database.ID, startSegment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query last WAL segment: %w", err)
|
||||
return nil, fmt.Errorf("failed to query WAL segments: %w", err)
|
||||
}
|
||||
|
||||
walCalculator := util_wal.NewWalCalculator(walSegmentSizeBytes)
|
||||
|
||||
var chainTail string
|
||||
if lastWal != nil && lastWal.PgWalSegmentName != nil {
|
||||
chainTail = *lastWal.PgWalSegmentName
|
||||
} else {
|
||||
chainTail = stopSegment
|
||||
}
|
||||
|
||||
expectedNext, err := walCalculator.NextSegment(chainTail)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("WAL arithmetic failed for %q: %w", chainTail, err)
|
||||
}
|
||||
|
||||
if incomingSegment != expectedNext {
|
||||
return &backups_dto.UploadGapResponse{
|
||||
Error: "gap_detected",
|
||||
ExpectedSegmentName: expectedNext,
|
||||
ReceivedSegmentName: incomingSegment,
|
||||
chainErr := s.validateRestoreWalChain(fullBackup, walSegments)
|
||||
if chainErr != nil {
|
||||
return &backups_dto.IsWalChainValidResponse{
|
||||
IsValid: false,
|
||||
Error: chainErr.Error,
|
||||
LastContiguousSegment: chainErr.LastContiguousSegment,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
return &backups_dto.IsWalChainValidResponse{
|
||||
IsValid: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *PostgreWalBackupService) createBackupRecord(
|
||||
func (s *PostgreWalBackupService) createBasebackupRecord(
|
||||
databaseID uuid.UUID,
|
||||
storageID uuid.UUID,
|
||||
uploadType backups_core.PgWalUploadType,
|
||||
dbName string,
|
||||
walSegmentName string,
|
||||
fullBackupWalStartSegment string,
|
||||
fullBackupWalStopSegment string,
|
||||
encryption backups_config.BackupEncryption,
|
||||
) *backups_core.Backup {
|
||||
now := time.Now().UTC()
|
||||
walBackupType := backups_core.PgWalBackupTypeFullBackup
|
||||
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: databaseID,
|
||||
StorageID: storageID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
Encryption: encryption,
|
||||
CreatedAt: now,
|
||||
ID: uuid.New(),
|
||||
DatabaseID: databaseID,
|
||||
StorageID: storageID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
PgWalBackupType: &walBackupType,
|
||||
Encryption: encryption,
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
backup.GenerateFilename(dbName)
|
||||
|
||||
if uploadType == backups_core.PgWalUploadTypeBasebackup {
|
||||
walBackupType := backups_core.PgWalBackupTypeFullBackup
|
||||
backup.PgWalBackupType = &walBackupType
|
||||
return backup
|
||||
}
|
||||
|
||||
if fullBackupWalStartSegment != "" {
|
||||
backup.PgFullBackupWalStartSegmentName = &fullBackupWalStartSegment
|
||||
}
|
||||
func (s *PostgreWalBackupService) createWalSegmentRecord(
|
||||
databaseID uuid.UUID,
|
||||
storageID uuid.UUID,
|
||||
dbName string,
|
||||
walSegmentName string,
|
||||
encryption backups_config.BackupEncryption,
|
||||
) *backups_core.Backup {
|
||||
now := time.Now().UTC()
|
||||
walBackupType := backups_core.PgWalBackupTypeWalSegment
|
||||
|
||||
if fullBackupWalStopSegment != "" {
|
||||
backup.PgFullBackupWalStopSegmentName = &fullBackupWalStopSegment
|
||||
}
|
||||
} else {
|
||||
walBackupType := backups_core.PgWalBackupTypeWalSegment
|
||||
backup.PgWalBackupType = &walBackupType
|
||||
backup.PgWalSegmentName = &walSegmentName
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: databaseID,
|
||||
StorageID: storageID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
PgWalBackupType: &walBackupType,
|
||||
PgWalSegmentName: &walSegmentName,
|
||||
Encryption: encryption,
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
backup.GenerateFilename(dbName)
|
||||
|
||||
return backup
|
||||
}
|
||||
|
||||
@@ -432,80 +571,6 @@ func (s *PostgreWalBackupService) markFailed(backup *backups_core.Backup, errMsg
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PostgreWalBackupService) GetNextFullBackupTime(
|
||||
database *databases.Database,
|
||||
) (*backups_dto.GetNextFullBackupTimeResponse, error) {
|
||||
if err := s.validateWalBackupType(database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get backup config: %w", err)
|
||||
}
|
||||
|
||||
if backupConfig.BackupInterval == nil {
|
||||
return nil, fmt.Errorf("no backup interval configured for database %s", database.ID)
|
||||
}
|
||||
|
||||
lastFullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(
|
||||
database.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query last full backup: %w", err)
|
||||
}
|
||||
|
||||
var lastBackupTime *time.Time
|
||||
if lastFullBackup != nil {
|
||||
lastBackupTime = &lastFullBackup.CreatedAt
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
nextTime := backupConfig.BackupInterval.NextTriggerTime(now, lastBackupTime)
|
||||
|
||||
return &backups_dto.GetNextFullBackupTimeResponse{
|
||||
NextFullBackupTime: nextTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ReportError creates a FAILED backup record with the agent's error message.
|
||||
func (s *PostgreWalBackupService) ReportError(
|
||||
database *databases.Database,
|
||||
errorMsg string,
|
||||
) error {
|
||||
if err := s.validateWalBackupType(database); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get backup config: %w", err)
|
||||
}
|
||||
|
||||
if backupConfig.Storage == nil {
|
||||
return fmt.Errorf("no storage configured for database %s", database.ID)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: backupConfig.Storage.ID,
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
FailMessage: &errorMsg,
|
||||
Encryption: backupConfig.Encryption,
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
backup.GenerateFilename(database.Name)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return fmt.Errorf("failed to save error backup record: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PostgreWalBackupService) resolveFullBackup(
|
||||
databaseID uuid.UUID,
|
||||
backupID *uuid.UUID,
|
||||
@@ -609,5 +674,5 @@ func (cr *countingReader) Read(p []byte) (n int, err error) {
|
||||
n, err = cr.r.Read(p)
|
||||
cr.n += int64(n)
|
||||
|
||||
return
|
||||
return n, err
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
@@ -23,8 +25,6 @@ import (
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackupService struct {
|
||||
|
||||
@@ -279,10 +279,10 @@ func (uc *CreateMariadbBackupUsecase) createTempMyCnfFile(
|
||||
password string,
|
||||
) (string, error) {
|
||||
tempFolder := config.GetEnv().TempFolder
|
||||
if err := os.MkdirAll(tempFolder, 0700); err != nil {
|
||||
if err := os.MkdirAll(tempFolder, 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
|
||||
}
|
||||
if err := os.Chmod(tempFolder, 0700); err != nil {
|
||||
if err := os.Chmod(tempFolder, 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
|
||||
}
|
||||
|
||||
@@ -291,7 +291,7 @@ func (uc *CreateMariadbBackupUsecase) createTempMyCnfFile(
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tempDir, 0700); err != nil {
|
||||
if err := os.Chmod(tempDir, 0o700); err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to set temp directory permissions: %w", err)
|
||||
}
|
||||
@@ -311,7 +311,7 @@ port=%d
|
||||
content += "ssl=false\n"
|
||||
}
|
||||
|
||||
err = os.WriteFile(myCnfFile, []byte(content), 0600)
|
||||
err = os.WriteFile(myCnfFile, []byte(content), 0o600)
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to write .my.cnf: %w", err)
|
||||
@@ -548,8 +548,8 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpErrorMessage(
|
||||
stderrStr,
|
||||
)
|
||||
|
||||
exitErr, ok := waitErr.(*exec.ExitError)
|
||||
if !ok {
|
||||
var exitErr *exec.ExitError
|
||||
if !errors.As(waitErr, &exitErr) {
|
||||
return errors.New(errorMsg)
|
||||
}
|
||||
|
||||
|
||||
@@ -298,10 +298,10 @@ func (uc *CreateMysqlBackupUsecase) createTempMyCnfFile(
|
||||
password string,
|
||||
) (string, error) {
|
||||
tempFolder := config.GetEnv().TempFolder
|
||||
if err := os.MkdirAll(tempFolder, 0700); err != nil {
|
||||
if err := os.MkdirAll(tempFolder, 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
|
||||
}
|
||||
if err := os.Chmod(tempFolder, 0700); err != nil {
|
||||
if err := os.Chmod(tempFolder, 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
|
||||
}
|
||||
|
||||
@@ -310,7 +310,7 @@ func (uc *CreateMysqlBackupUsecase) createTempMyCnfFile(
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tempDir, 0700); err != nil {
|
||||
if err := os.Chmod(tempDir, 0o700); err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to set temp directory permissions: %w", err)
|
||||
}
|
||||
@@ -328,7 +328,7 @@ port=%d
|
||||
content += "ssl-mode=REQUIRED\n"
|
||||
}
|
||||
|
||||
err = os.WriteFile(myCnfFile, []byte(content), 0600)
|
||||
err = os.WriteFile(myCnfFile, []byte(content), 0o600)
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to write .my.cnf: %w", err)
|
||||
@@ -565,8 +565,8 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpErrorMessage(
|
||||
stderrStr,
|
||||
)
|
||||
|
||||
exitErr, ok := waitErr.(*exec.ExitError)
|
||||
if !ok {
|
||||
var exitErr *exec.ExitError
|
||||
if !errors.As(waitErr, &exitErr) {
|
||||
return errors.New(errorMsg)
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
@@ -24,8 +26,6 @@ import (
|
||||
"databasus-backend/internal/features/storages"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/tools"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -595,8 +595,8 @@ func (uc *CreatePostgresqlBackupUsecase) buildPgDumpErrorMessage(
|
||||
stderrStr := string(stderrOutput)
|
||||
errorMsg := fmt.Sprintf("%s failed: %v – stderr: %s", filepath.Base(pgBin), waitErr, stderrStr)
|
||||
|
||||
exitErr, ok := waitErr.(*exec.ExitError)
|
||||
if !ok {
|
||||
var exitErr *exec.ExitError
|
||||
if !errors.As(waitErr, &exitErr) {
|
||||
return errors.New(errorMsg)
|
||||
}
|
||||
|
||||
@@ -748,10 +748,10 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
|
||||
)
|
||||
|
||||
tempFolder := config.GetEnv().TempFolder
|
||||
if err := os.MkdirAll(tempFolder, 0700); err != nil {
|
||||
if err := os.MkdirAll(tempFolder, 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
|
||||
}
|
||||
if err := os.Chmod(tempFolder, 0700); err != nil {
|
||||
if err := os.Chmod(tempFolder, 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
|
||||
}
|
||||
|
||||
@@ -760,13 +760,13 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
|
||||
return "", fmt.Errorf("failed to create temporary directory: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tempDir, 0700); err != nil {
|
||||
if err := os.Chmod(tempDir, 0o700); err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to set temporary directory permissions: %w", err)
|
||||
}
|
||||
|
||||
pgpassFile := filepath.Join(tempDir, ".pgpass")
|
||||
err = os.WriteFile(pgpassFile, []byte(pgpassContent), 0600)
|
||||
err = os.WriteFile(pgpassFile, []byte(pgpassContent), 0o600)
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to write temporary .pgpass file: %w", err)
|
||||
|
||||
@@ -4,10 +4,10 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
)
|
||||
|
||||
type BackupConfigController struct {
|
||||
|
||||
@@ -12,16 +12,19 @@ import (
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var backupConfigRepository = &BackupConfigRepository{}
|
||||
var backupConfigService = &BackupConfigService{
|
||||
backupConfigRepository,
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
plans.GetDatabasePlanService(),
|
||||
nil,
|
||||
}
|
||||
var (
|
||||
backupConfigRepository = &BackupConfigRepository{}
|
||||
backupConfigService = &BackupConfigService{
|
||||
backupConfigRepository,
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
plans.GetDatabasePlanService(),
|
||||
nil,
|
||||
}
|
||||
)
|
||||
|
||||
var backupConfigController = &BackupConfigController{
|
||||
backupConfigService,
|
||||
}
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
package backups_config
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/features/storages"
|
||||
"databasus-backend/internal/util/period"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/features/storages"
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
|
||||
type BackupConfig struct {
|
||||
@@ -43,7 +44,7 @@ type BackupConfig struct {
|
||||
Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
|
||||
|
||||
// MaxBackupSizeMB limits individual backup size. 0 = unlimited.
|
||||
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
|
||||
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
|
||||
// MaxBackupsTotalSizeMB limits total size of all backups. 0 = unlimited.
|
||||
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
|
||||
}
|
||||
|
||||
@@ -3,12 +3,12 @@ package backups_config
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/features/intervals"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/util/period"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Validate_WhenRetentionTimePeriodIsWeekAndPlanAllowsMonth_ValidationPasses(t *testing.T) {
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
package backups_config
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
"errors"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
type BackupConfigRepository struct{}
|
||||
@@ -47,7 +48,6 @@ func (r *BackupConfigRepository) Save(
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package backups_config
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
@@ -10,8 +12,6 @@ import (
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_models "databasus-backend/internal/features/users/models"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackupConfigService struct {
|
||||
@@ -214,39 +214,6 @@ func (s *BackupConfigService) CreateDisabledBackupConfig(databaseID uuid.UUID) e
|
||||
return s.initializeDefaultConfig(databaseID)
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) initializeDefaultConfig(
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
plan, err := s.databasePlanService.GetDatabasePlan(databaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
timeOfDay := "04:00"
|
||||
|
||||
_, err = s.backupConfigRepository.Save(&BackupConfig{
|
||||
DatabaseID: databaseID,
|
||||
IsBackupsEnabled: false,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: plan.MaxStoragePeriod,
|
||||
MaxBackupSizeMB: plan.MaxBackupSizeMB,
|
||||
MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
NotificationBackupSuccess,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) TransferDatabaseToWorkspace(
|
||||
user *users_models.User,
|
||||
databaseID uuid.UUID,
|
||||
@@ -290,7 +257,8 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace(
|
||||
s.transferNotifiers(user, database, request.TargetWorkspaceID)
|
||||
}
|
||||
|
||||
if request.IsTransferWithStorage {
|
||||
switch {
|
||||
case request.IsTransferWithStorage:
|
||||
if backupConfig.StorageID == nil {
|
||||
return ErrDatabaseHasNoStorage
|
||||
}
|
||||
@@ -315,7 +283,7 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace(
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if request.TargetStorageID != nil {
|
||||
case request.TargetStorageID != nil:
|
||||
targetStorage, err := s.storageService.GetStorageByID(*request.TargetStorageID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -332,7 +300,7 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace(
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
return ErrTargetStorageNotSpecified
|
||||
}
|
||||
|
||||
@@ -351,6 +319,39 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) initializeDefaultConfig(
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
plan, err := s.databasePlanService.GetDatabasePlan(databaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
timeOfDay := "04:00"
|
||||
|
||||
_, err = s.backupConfigRepository.Save(&BackupConfig{
|
||||
DatabaseID: databaseID,
|
||||
IsBackupsEnabled: false,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: plan.MaxStoragePeriod,
|
||||
MaxBackupSizeMB: plan.MaxBackupSizeMB,
|
||||
MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
NotificationBackupSuccess,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) transferNotifiers(
|
||||
user *users_models.User,
|
||||
database *databases.Database,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user