mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
Compare commits
292 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7123de9fa3 | ||
|
|
d1c41ed53a | ||
|
|
c2ddbfc86f | ||
|
|
f287967b5d | ||
|
|
ef879df08f | ||
|
|
44ddcb836e | ||
|
|
b5178f5752 | ||
|
|
7913c1b474 | ||
|
|
2815cc3752 | ||
|
|
189573fa1b | ||
|
|
81f77760c9 | ||
|
|
63e23b2489 | ||
|
|
8c1b8ac00f | ||
|
|
1926096377 | ||
|
|
0a131511a8 | ||
|
|
aa01ce0b76 | ||
|
|
496fc05993 | ||
|
|
1ac0eb4d5b | ||
|
|
6b052902f7 | ||
|
|
c7d091fe51 | ||
|
|
b1dfd1c425 | ||
|
|
4bee78646a | ||
|
|
927eeabc0f | ||
|
|
3a5a53c92d | ||
|
|
f0ab470a84 | ||
|
|
f728fda759 | ||
|
|
80b5df6283 | ||
|
|
67556a0db1 | ||
|
|
c4cf7f8446 | ||
|
|
61a0bcabb1 | ||
|
|
f7f70a13eb | ||
|
|
f1e289c421 | ||
|
|
c0952e057f | ||
|
|
b4d4e0a1d7 | ||
|
|
c648e9c29f | ||
|
|
3fce6d2a99 | ||
|
|
198b94ba9d | ||
|
|
80cd0bf5d3 | ||
|
|
231e3cc709 | ||
|
|
8cf0fdacb1 | ||
|
|
2d28af19dc | ||
|
|
67dc257fda | ||
|
|
881167f812 | ||
|
|
cf807cfc54 | ||
|
|
df91651709 | ||
|
|
b0592dae9e | ||
|
|
c997202484 | ||
|
|
a17ea2f3e2 | ||
|
|
856aa1c256 | ||
|
|
f60f677351 | ||
|
|
4c980746ab | ||
|
|
89197bbbc6 | ||
|
|
e2ac5bfbd7 | ||
|
|
cf6e8f212a | ||
|
|
6ee7e02f5d | ||
|
|
14bcd3d70b | ||
|
|
5faa11f82a | ||
|
|
2c4e3e567b | ||
|
|
82d615545b | ||
|
|
e913f4c32e | ||
|
|
57a75918e4 | ||
|
|
8a601c7f68 | ||
|
|
f0064b4be3 | ||
|
|
94505bab3f | ||
|
|
9acf3cff09 | ||
|
|
0d7e147df6 | ||
|
|
1394b47570 | ||
|
|
a9865ae3e4 | ||
|
|
4b5478e60a | ||
|
|
6355301903 | ||
|
|
29b403a9c6 | ||
|
|
12606053f4 | ||
|
|
904b386378 | ||
|
|
1d9738b808 | ||
|
|
58b37f4c92 | ||
|
|
6c4f814c94 | ||
|
|
bcd13c27d3 | ||
|
|
120f9600bf | ||
|
|
563c7c1d64 | ||
|
|
68f15f7661 | ||
|
|
627d96a00d | ||
|
|
02b9a9ec8d | ||
|
|
415dda8752 | ||
|
|
3faf85796a | ||
|
|
edd2759f5a | ||
|
|
c283856f38 | ||
|
|
6059e1a33b | ||
|
|
2deda2e7ea | ||
|
|
acf1143752 | ||
|
|
889063a8b4 | ||
|
|
a1e20e7b10 | ||
|
|
7e76945550 | ||
|
|
d98acfc4af | ||
|
|
0ffc7c8c96 | ||
|
|
1b011bdcd4 | ||
|
|
7e209ff537 | ||
|
|
f712e3a437 | ||
|
|
bcd7d8e1aa | ||
|
|
880a7488e9 | ||
|
|
ca4d483f2c | ||
|
|
1b511410a6 | ||
|
|
c8edff8046 | ||
|
|
f60e3d956b | ||
|
|
f2cb9022f2 | ||
|
|
4b3f36eea2 | ||
|
|
460063e7a5 | ||
|
|
a0f02b253e | ||
|
|
812f11bc2f | ||
|
|
e796e3ddf0 | ||
|
|
c96d3db337 | ||
|
|
ed6c3a2034 | ||
|
|
05115047c3 | ||
|
|
446b96c6c0 | ||
|
|
36a0448da1 | ||
|
|
8e392cfeab | ||
|
|
6683db1e52 | ||
|
|
703b883936 | ||
|
|
e818bcff82 | ||
|
|
b2f98f1332 | ||
|
|
230cc27ea6 | ||
|
|
cd197ff94b | ||
|
|
91f35a3e17 | ||
|
|
30c2e2d156 | ||
|
|
ef7c5b45e6 | ||
|
|
920c98e229 | ||
|
|
2a19a96aae | ||
|
|
75aa2108d9 | ||
|
|
0a0040839e | ||
|
|
ff4f795ece | ||
|
|
dc05502580 | ||
|
|
1ca38f5583 | ||
|
|
40b3ff61c7 | ||
|
|
e1b245a573 | ||
|
|
fdf29b71f2 | ||
|
|
49da981c21 | ||
|
|
9d611d3559 | ||
|
|
22cab53dab | ||
|
|
d761c4156c | ||
|
|
cbb8b82711 | ||
|
|
8e3d1e5bff | ||
|
|
349e7f0ee8 | ||
|
|
3a274e135b | ||
|
|
61e937bc2a | ||
|
|
f67919fe1a | ||
|
|
91ee5966d8 | ||
|
|
d77d7d69a3 | ||
|
|
fc88b730d5 | ||
|
|
1f1d80245f | ||
|
|
16a29cf458 | ||
|
|
43e04500ac | ||
|
|
cee3022f85 | ||
|
|
f46d92c480 | ||
|
|
10677238d7 | ||
|
|
2553203fcf | ||
|
|
7b05bd8000 | ||
|
|
8d45728f73 | ||
|
|
c70ad82c95 | ||
|
|
e4bc34d319 | ||
|
|
257ae85da7 | ||
|
|
b42c820bb2 | ||
|
|
da5c13fb11 | ||
|
|
35180360e5 | ||
|
|
e4f6cd7a5d | ||
|
|
d7b8e6d56a | ||
|
|
6016f23fb2 | ||
|
|
e7c4ee8f6f | ||
|
|
a75702a01b | ||
|
|
81a21eb907 | ||
|
|
33d6bf0147 | ||
|
|
6eb53bb07b | ||
|
|
6ac04270b9 | ||
|
|
b0510d7c21 | ||
|
|
dc5f271882 | ||
|
|
8f718771c9 | ||
|
|
d8eea05dca | ||
|
|
b2a94274d7 | ||
|
|
77c2712ebb | ||
|
|
a9dc29f82c | ||
|
|
c934a45dca | ||
|
|
d4acdf2826 | ||
|
|
49753c4fc0 | ||
|
|
c6aed6b36d | ||
|
|
3060b4266a | ||
|
|
ebeb597f17 | ||
|
|
4783784325 | ||
|
|
bd41433bdb | ||
|
|
a9073787d2 | ||
|
|
0890bf8f09 | ||
|
|
f8c11e8802 | ||
|
|
e798d82fc1 | ||
|
|
81a01585ee | ||
|
|
a8465c1a10 | ||
|
|
a9e5db70f6 | ||
|
|
7a47be6ca6 | ||
|
|
16be3db0c6 | ||
|
|
744e51d1e1 | ||
|
|
b3af75d430 | ||
|
|
6f7320abeb | ||
|
|
a1655d35a6 | ||
|
|
9b6e801184 | ||
|
|
105777ab6f | ||
|
|
3a1a88d5cf | ||
|
|
699ca16814 | ||
|
|
26f3cf233a | ||
|
|
3d8372e9f6 | ||
|
|
b46f11804d | ||
|
|
4676361688 | ||
|
|
de3679cadf | ||
|
|
8f03a30af2 | ||
|
|
356529c58a | ||
|
|
e7eed056f7 | ||
|
|
6084cdc954 | ||
|
|
c50bcc57b1 | ||
|
|
ea76300ed7 | ||
|
|
9b413e4076 | ||
|
|
f91cb260f2 | ||
|
|
8f37a8082f | ||
|
|
5cf7614772 | ||
|
|
ae27f74c2e | ||
|
|
9457516bb9 | ||
|
|
a36fc5bf8c | ||
|
|
03ada5806d | ||
|
|
a6675390e5 | ||
|
|
af2f978876 | ||
|
|
04e7eba5c5 | ||
|
|
520165541d | ||
|
|
5b556bc161 | ||
|
|
0952a15ec5 | ||
|
|
1afb3aa3ff | ||
|
|
19b92e5f74 | ||
|
|
d4763f26b2 | ||
|
|
0e389ba16b | ||
|
|
594a3294c6 | ||
|
|
4e4a323cf1 | ||
|
|
7d9ecf697b | ||
|
|
755c420157 | ||
|
|
ff73627287 | ||
|
|
9c9ab00ace | ||
|
|
7366e21a1a | ||
|
|
a327d1aa57 | ||
|
|
f152b16ea3 | ||
|
|
85dbe80d3d | ||
|
|
edf4028fd1 | ||
|
|
8d85c45a90 | ||
|
|
d9c176d19a | ||
|
|
7a6f72a456 | ||
|
|
9a1471b88b | ||
|
|
386ea1d708 | ||
|
|
a4b23936ee | ||
|
|
b36aa9d48b | ||
|
|
13cb8e5bd2 | ||
|
|
2db4b6e075 | ||
|
|
f2b0b2bf1f | ||
|
|
7142ce295e | ||
|
|
04621b9b2d | ||
|
|
bd329a68cf | ||
|
|
f957abc9db | ||
|
|
c0fd6be1a9 | ||
|
|
c39bd34d5e | ||
|
|
27bec15a29 | ||
|
|
d98baa0656 | ||
|
|
4344f5ea5e | ||
|
|
7c6afa5b88 | ||
|
|
dbac799e1b | ||
|
|
7ee3817089 | ||
|
|
bae6f7f007 | ||
|
|
55dc087ddd | ||
|
|
c94d0db637 | ||
|
|
a1adef2261 | ||
|
|
4602dc3f88 | ||
|
|
cbbfc5ea8f | ||
|
|
dd1072e230 | ||
|
|
a495e5317a | ||
|
|
7eed647038 | ||
|
|
6973241e25 | ||
|
|
ab181f5b81 | ||
|
|
b60a0cc170 | ||
|
|
f319a497b3 | ||
|
|
bc870b3f8e | ||
|
|
15383c59eb | ||
|
|
d14c223a65 | ||
|
|
2c0a294027 | ||
|
|
5d851d73bd | ||
|
|
699913c251 | ||
|
|
a2e3f30a6d | ||
|
|
80f1174ecd | ||
|
|
a47f8d5e2c | ||
|
|
54b9e67656 | ||
|
|
3782846872 | ||
|
|
245a81897f | ||
|
|
5cbc0773b6 | ||
|
|
997fc01442 |
44
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
44
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
---
|
||||
name: Bug Report
|
||||
about: Report a bug or unexpected behavior in Databasus
|
||||
labels: bug
|
||||
---
|
||||
|
||||
## Databasus version (screenshot)
|
||||
|
||||
It is displayed in the bottom left corner of the Databasus UI. Please attach screenshot, not just version text
|
||||
|
||||
<!-- e.g. 1.4.2 -->
|
||||
|
||||
## Operating system and architecture
|
||||
|
||||
<!-- e.g. Ubuntu 22.04 x64, macOS 14 ARM, Windows 11 x64 -->
|
||||
|
||||
## Database type and version (optional, for DB-related bugs)
|
||||
|
||||
<!-- e.g. PostgreSQL 16 in Docker, MySQL 8.0 installed on server, MariaDB 11.4 in AWS Cloud -->
|
||||
|
||||
## Describe the bug (please write manually, do not ask AI to summarize)
|
||||
|
||||
**What happened:**
|
||||
|
||||
**What I expected:**
|
||||
|
||||
## Steps to reproduce
|
||||
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
## Have you asked AI how to solve the issue?
|
||||
|
||||
<!-- Using AI to diagnose issues before filing a bug report helps narrow down root causes. -->
|
||||
|
||||
- [ ] Claude Sonnet 4.6 or newer
|
||||
- [ ] ChatGPT 5.2 or newer
|
||||
- [ ] No
|
||||
|
||||
|
||||
## Additional context / logs
|
||||
|
||||
<!-- Screenshots, error messages, relevant log output, etc. -->
|
||||
500
.github/workflows/ci-release.yml
vendored
500
.github/workflows/ci-release.yml
vendored
@@ -9,29 +9,31 @@ on:
|
||||
|
||||
jobs:
|
||||
lint-backend:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.ref != 'refs/heads/develop'
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: golang:1.26.1
|
||||
volumes:
|
||||
- /runner-cache/go-pkg:/go/pkg/mod
|
||||
- /runner-cache/go-build:/root/.cache/go-build
|
||||
- /runner-cache/golangci-lint:/root/.cache/golangci-lint
|
||||
- /runner-cache/apt-archives:/var/cache/apt/archives
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.24.4"
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
- name: Download Go modules
|
||||
run: |
|
||||
cd backend
|
||||
go mod download
|
||||
|
||||
- name: Install golangci-lint
|
||||
run: |
|
||||
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.7.2
|
||||
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.11.3
|
||||
echo "$(go env GOPATH)/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Install swag for swagger generation
|
||||
@@ -54,6 +56,7 @@ jobs:
|
||||
git diff --exit-code go.mod go.sum || (echo "go mod tidy made changes, please run 'go mod tidy' and commit the changes" && exit 1)
|
||||
|
||||
lint-frontend:
|
||||
if: github.ref != 'refs/heads/develop'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out code
|
||||
@@ -63,8 +66,6 @@ jobs:
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
cache: "npm"
|
||||
cache-dependency-path: frontend/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
@@ -82,7 +83,47 @@ jobs:
|
||||
cd frontend
|
||||
npm run lint
|
||||
|
||||
- name: Build frontend
|
||||
run: |
|
||||
cd frontend
|
||||
npm run build
|
||||
|
||||
lint-agent:
|
||||
if: github.ref != 'refs/heads/develop'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.26.1"
|
||||
cache-dependency-path: agent/go.sum
|
||||
|
||||
- name: Download Go modules
|
||||
run: |
|
||||
cd agent
|
||||
go mod download
|
||||
|
||||
- name: Install golangci-lint
|
||||
run: |
|
||||
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.11.3
|
||||
echo "$(go env GOPATH)/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Run golangci-lint
|
||||
run: |
|
||||
cd agent
|
||||
golangci-lint run
|
||||
|
||||
- name: Verify go mod tidy
|
||||
run: |
|
||||
cd agent
|
||||
go mod tidy
|
||||
git diff --exit-code go.mod go.sum || (echo "go mod tidy made changes, please run 'go mod tidy' and commit the changes" && exit 1)
|
||||
|
||||
test-frontend:
|
||||
if: github.ref != 'refs/heads/develop'
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint-frontend]
|
||||
steps:
|
||||
@@ -93,8 +134,6 @@ jobs:
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
cache: "npm"
|
||||
cache-dependency-path: frontend/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
@@ -106,45 +145,104 @@ jobs:
|
||||
cd frontend
|
||||
npm run test
|
||||
|
||||
test-backend:
|
||||
test-agent:
|
||||
if: github.ref != 'refs/heads/develop'
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint-backend]
|
||||
needs: [lint-agent]
|
||||
steps:
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
echo "Disk space before cleanup:"
|
||||
df -h
|
||||
# Remove unnecessary pre-installed software
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo rm -rf /usr/local/share/boost
|
||||
sudo rm -rf /usr/share/swift
|
||||
# Clean apt cache
|
||||
sudo apt-get clean
|
||||
# Clean docker images (if any pre-installed)
|
||||
docker system prune -af --volumes || true
|
||||
echo "Disk space after cleanup:"
|
||||
df -h
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.24.4"
|
||||
go-version: "1.26.1"
|
||||
cache-dependency-path: agent/go.sum
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
- name: Download Go modules
|
||||
run: |
|
||||
cd agent
|
||||
go mod download
|
||||
|
||||
- name: Run Go tests
|
||||
run: |
|
||||
cd agent
|
||||
go test -count=1 -failfast ./internal/...
|
||||
|
||||
e2e-agent:
|
||||
if: github.ref != 'refs/heads/develop'
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint-agent]
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Run e2e tests
|
||||
run: |
|
||||
cd agent
|
||||
make e2e
|
||||
|
||||
- name: Cleanup
|
||||
if: always()
|
||||
run: |
|
||||
cd agent/e2e
|
||||
docker compose down -v --rmi local || true
|
||||
rm -rf artifacts || true
|
||||
|
||||
e2e-agent-backup-restore:
|
||||
if: github.ref != 'refs/heads/develop'
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint-agent]
|
||||
strategy:
|
||||
matrix:
|
||||
pg_version: [15, 16, 17, 18]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Run backup-restore e2e (PG ${{ matrix.pg_version }})
|
||||
run: |
|
||||
cd agent
|
||||
make e2e-backup-restore PG_VERSION=${{ matrix.pg_version }}
|
||||
|
||||
- name: Cleanup
|
||||
if: always()
|
||||
run: |
|
||||
cd agent/e2e
|
||||
docker compose -f docker-compose.backup-restore.yml down -v --rmi local || true
|
||||
rm -rf artifacts || true
|
||||
|
||||
# Self-hosted: performant high-frequency CPU is used to start many containers and run tests fast. Tests
|
||||
# step is bottle-neck, because we need a lot of containers and cannot parallelize tests due to shared resources
|
||||
test-backend:
|
||||
if: github.ref != 'refs/heads/develop'
|
||||
runs-on: self-hosted
|
||||
needs: [lint-backend]
|
||||
container:
|
||||
image: golang:1.26.1
|
||||
options: --privileged -v /var/run/docker.sock:/var/run/docker.sock --add-host=host.docker.internal:host-gateway
|
||||
volumes:
|
||||
- /runner-cache/go-pkg:/go/pkg/mod
|
||||
- /runner-cache/go-build:/root/.cache/go-build
|
||||
- /runner-cache/apt-archives:/var/cache/apt/archives
|
||||
steps:
|
||||
- name: Install Docker CLI
|
||||
run: |
|
||||
apt-get update -qq
|
||||
apt-get install -y -qq docker.io docker-compose netcat-openbsd wget
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Download Go modules
|
||||
run: |
|
||||
cd backend
|
||||
go mod download
|
||||
|
||||
- name: Create .env file for testing
|
||||
run: |
|
||||
@@ -156,14 +254,16 @@ jobs:
|
||||
DEV_DB_PASSWORD=Q1234567
|
||||
#app
|
||||
ENV_MODE=development
|
||||
# db
|
||||
DATABASE_DSN=host=localhost user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
|
||||
DATABASE_URL=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
|
||||
# db - using 172.17.0.1 to access host from container
|
||||
DATABASE_DSN=host=172.17.0.1 user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
|
||||
DATABASE_URL=postgres://postgres:Q1234567@172.17.0.1:5437/databasus?sslmode=disable
|
||||
# migrations
|
||||
GOOSE_DRIVER=postgres
|
||||
GOOSE_DBSTRING=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
|
||||
GOOSE_DBSTRING=postgres://postgres:Q1234567@172.17.0.1:5437/databasus?sslmode=disable
|
||||
GOOSE_MIGRATION_DIR=./migrations
|
||||
# testing
|
||||
# testing
|
||||
TEST_LOCALHOST=172.17.0.1
|
||||
IS_SKIP_EXTERNAL_RESOURCES_TESTS=true
|
||||
# to get Google Drive env variables: add storage in UI and copy data from added storage here
|
||||
TEST_GOOGLE_DRIVE_CLIENT_ID=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_ID }}
|
||||
TEST_GOOGLE_DRIVE_CLIENT_SECRET=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_SECRET }}
|
||||
@@ -204,9 +304,6 @@ jobs:
|
||||
TEST_MARIADB_114_PORT=33114
|
||||
TEST_MARIADB_118_PORT=33118
|
||||
TEST_MARIADB_120_PORT=33120
|
||||
# testing Telegram
|
||||
TEST_TELEGRAM_BOT_TOKEN=${{ secrets.TEST_TELEGRAM_BOT_TOKEN }}
|
||||
TEST_TELEGRAM_CHAT_ID=${{ secrets.TEST_TELEGRAM_CHAT_ID }}
|
||||
# supabase
|
||||
TEST_SUPABASE_HOST=${{ secrets.TEST_SUPABASE_HOST }}
|
||||
TEST_SUPABASE_PORT=${{ secrets.TEST_SUPABASE_PORT }}
|
||||
@@ -221,6 +318,14 @@ jobs:
|
||||
TEST_MONGODB_60_PORT=27060
|
||||
TEST_MONGODB_70_PORT=27070
|
||||
TEST_MONGODB_82_PORT=27082
|
||||
# Valkey (cache) - using 172.17.0.1
|
||||
VALKEY_HOST=172.17.0.1
|
||||
VALKEY_PORT=6379
|
||||
VALKEY_USERNAME=
|
||||
VALKEY_PASSWORD=
|
||||
VALKEY_IS_SSL=false
|
||||
# Host for test databases (container -> host)
|
||||
TEST_DB_HOST=172.17.0.1
|
||||
EOF
|
||||
|
||||
- name: Start test containers
|
||||
@@ -233,25 +338,30 @@ jobs:
|
||||
# Wait for main dev database
|
||||
timeout 60 bash -c 'until docker exec dev-db pg_isready -h localhost -p 5437 -U postgres; do sleep 2; done'
|
||||
|
||||
# Wait for test databases
|
||||
timeout 60 bash -c 'until nc -z localhost 5000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5001; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5002; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5003; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5004; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5005; do sleep 2; done'
|
||||
# Wait for Valkey (cache)
|
||||
echo "Waiting for Valkey..."
|
||||
timeout 60 bash -c 'until docker exec dev-valkey valkey-cli ping 2>/dev/null | grep -q PONG; do sleep 2; done'
|
||||
echo "Valkey is ready!"
|
||||
|
||||
# Wait for test databases (using 172.17.0.1 from container)
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5001; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5002; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5003; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5004; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5005; do sleep 2; done'
|
||||
|
||||
# Wait for MinIO
|
||||
timeout 60 bash -c 'until nc -z localhost 9000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 9000; do sleep 2; done'
|
||||
|
||||
# Wait for Azurite
|
||||
timeout 60 bash -c 'until nc -z localhost 10000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 10000; do sleep 2; done'
|
||||
|
||||
# Wait for FTP
|
||||
timeout 60 bash -c 'until nc -z localhost 7007; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 7007; do sleep 2; done'
|
||||
|
||||
# Wait for SFTP
|
||||
timeout 60 bash -c 'until nc -z localhost 7008; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 7008; do sleep 2; done'
|
||||
|
||||
# Wait for MySQL containers
|
||||
echo "Waiting for MySQL 5.7..."
|
||||
@@ -310,67 +420,66 @@ jobs:
|
||||
mkdir -p databasus-data/backups
|
||||
mkdir -p databasus-data/temp
|
||||
|
||||
- name: Cache PostgreSQL client tools
|
||||
id: cache-postgres
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: /usr/lib/postgresql
|
||||
key: postgres-clients-12-18-v1
|
||||
|
||||
- name: Cache MySQL client tools
|
||||
id: cache-mysql
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: backend/tools/mysql
|
||||
key: mysql-clients-57-80-84-9-v1
|
||||
|
||||
- name: Cache MariaDB client tools
|
||||
id: cache-mariadb
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: backend/tools/mariadb
|
||||
key: mariadb-clients-106-121-v1
|
||||
|
||||
- name: Cache MongoDB Database Tools
|
||||
id: cache-mongodb
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: backend/tools/mongodb
|
||||
key: mongodb-database-tools-100.10.0-v1
|
||||
|
||||
- name: Install MySQL dependencies
|
||||
- name: Install database client dependencies
|
||||
run: |
|
||||
sudo apt-get update -qq
|
||||
sudo apt-get install -y -qq libncurses6
|
||||
sudo ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5
|
||||
sudo ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5
|
||||
apt-get update -qq
|
||||
apt-get install -y -qq libncurses6 libpq5
|
||||
ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5 || true
|
||||
ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5 || true
|
||||
|
||||
- name: Install PostgreSQL, MySQL, MariaDB and MongoDB client tools
|
||||
if: steps.cache-postgres.outputs.cache-hit != 'true' || steps.cache-mysql.outputs.cache-hit != 'true' || steps.cache-mariadb.outputs.cache-hit != 'true' || steps.cache-mongodb.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
chmod +x backend/tools/download_linux.sh
|
||||
cd backend/tools
|
||||
./download_linux.sh
|
||||
|
||||
- name: Setup PostgreSQL symlinks (when using cache)
|
||||
if: steps.cache-postgres.outputs.cache-hit == 'true'
|
||||
- name: Setup PostgreSQL, MySQL and MariaDB client tools from pre-built assets
|
||||
run: |
|
||||
cd backend/tools
|
||||
mkdir -p postgresql
|
||||
|
||||
# Create directory structure
|
||||
mkdir -p postgresql mysql mariadb mongodb/bin
|
||||
|
||||
# Copy PostgreSQL client tools (12-18) from pre-built assets
|
||||
for version in 12 13 14 15 16 17 18; do
|
||||
version_dir="postgresql/postgresql-$version"
|
||||
mkdir -p "$version_dir/bin"
|
||||
pg_bin_dir="/usr/lib/postgresql/$version/bin"
|
||||
if [ -d "$pg_bin_dir" ]; then
|
||||
ln -sf "$pg_bin_dir/pg_dump" "$version_dir/bin/pg_dump"
|
||||
ln -sf "$pg_bin_dir/pg_dumpall" "$version_dir/bin/pg_dumpall"
|
||||
ln -sf "$pg_bin_dir/psql" "$version_dir/bin/psql"
|
||||
ln -sf "$pg_bin_dir/pg_restore" "$version_dir/bin/pg_restore"
|
||||
ln -sf "$pg_bin_dir/createdb" "$version_dir/bin/createdb"
|
||||
ln -sf "$pg_bin_dir/dropdb" "$version_dir/bin/dropdb"
|
||||
fi
|
||||
mkdir -p postgresql/postgresql-$version
|
||||
cp -r ../../assets/tools/x64/postgresql/postgresql-$version/bin postgresql/postgresql-$version/
|
||||
done
|
||||
|
||||
# Copy MySQL client tools (5.7, 8.0, 8.4, 9) from pre-built assets
|
||||
for version in 5.7 8.0 8.4 9; do
|
||||
mkdir -p mysql/mysql-$version
|
||||
cp -r ../../assets/tools/x64/mysql/mysql-$version/bin mysql/mysql-$version/
|
||||
done
|
||||
|
||||
# Copy MariaDB client tools (10.6, 12.1) from pre-built assets
|
||||
for version in 10.6 12.1; do
|
||||
mkdir -p mariadb/mariadb-$version
|
||||
cp -r ../../assets/tools/x64/mariadb/mariadb-$version/bin mariadb/mariadb-$version/
|
||||
done
|
||||
|
||||
# Make all binaries executable
|
||||
chmod +x postgresql/*/bin/*
|
||||
chmod +x mysql/*/bin/*
|
||||
chmod +x mariadb/*/bin/*
|
||||
|
||||
echo "Pre-built client tools setup complete"
|
||||
|
||||
- name: Install MongoDB Database Tools
|
||||
run: |
|
||||
cd backend/tools
|
||||
|
||||
# MongoDB Database Tools must be downloaded (not in pre-built assets)
|
||||
# They are backward compatible - single version supports all servers (4.0-8.0)
|
||||
MONGODB_TOOLS_URL="https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-x86_64-100.10.0.deb"
|
||||
|
||||
echo "Downloading MongoDB Database Tools..."
|
||||
wget -q "$MONGODB_TOOLS_URL" -O /tmp/mongodb-database-tools.deb
|
||||
|
||||
echo "Installing MongoDB Database Tools..."
|
||||
dpkg -i /tmp/mongodb-database-tools.deb || apt-get install -f -y --no-install-recommends
|
||||
|
||||
# Create symlinks to tools directory
|
||||
ln -sf /usr/bin/mongodump mongodb/bin/mongodump
|
||||
ln -sf /usr/bin/mongorestore mongodb/bin/mongorestore
|
||||
|
||||
rm -f /tmp/mongodb-database-tools.deb
|
||||
echo "MongoDB Database Tools installed successfully"
|
||||
|
||||
- name: Verify MariaDB client tools exist
|
||||
run: |
|
||||
cd backend/tools
|
||||
@@ -403,7 +512,7 @@ jobs:
|
||||
- name: Run database migrations
|
||||
run: |
|
||||
cd backend
|
||||
go install github.com/pressly/goose/v3/cmd/goose@latest
|
||||
go install github.com/pressly/goose/v3/cmd/goose@v3.24.3
|
||||
goose up
|
||||
|
||||
- name: Run Go tests
|
||||
@@ -415,11 +524,65 @@ jobs:
|
||||
if: always()
|
||||
run: |
|
||||
cd backend
|
||||
# Stop and remove containers (keeping images for next run)
|
||||
docker compose -f docker-compose.yml.example down -v
|
||||
|
||||
# Clean up all data directories created by docker-compose
|
||||
echo "Cleaning up data directories..."
|
||||
rm -rf pgdata || true
|
||||
rm -rf valkey-data || true
|
||||
rm -rf mysqldata || true
|
||||
rm -rf mariadbdata || true
|
||||
rm -rf temp/nas || true
|
||||
rm -rf databasus-data || true
|
||||
|
||||
# Also clean root-level databasus-data if exists
|
||||
cd ..
|
||||
rm -rf databasus-data || true
|
||||
|
||||
echo "Cleanup complete"
|
||||
|
||||
build-and-push-dev:
|
||||
runs-on: self-hosted
|
||||
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/develop' }}
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up QEMU (enables multi-arch emulation)
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Build and push dev image
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
platforms: linux/amd64,linux/arm64
|
||||
build-args: |
|
||||
APP_VERSION=dev-${{ github.sha }}
|
||||
tags: |
|
||||
databasus/databasus-dev:latest
|
||||
databasus/databasus-dev:${{ github.sha }}
|
||||
|
||||
determine-version:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [test-backend, test-frontend]
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: node:20
|
||||
needs: [test-backend, test-frontend, test-agent, e2e-agent, e2e-agent-backup-restore]
|
||||
if: ${{ github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, '[skip-release]') }}
|
||||
outputs:
|
||||
should_release: ${{ steps.version_bump.outputs.should_release }}
|
||||
@@ -431,10 +594,9 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Install semver
|
||||
run: npm install -g semver
|
||||
@@ -448,6 +610,7 @@ jobs:
|
||||
|
||||
- name: Analyze commits and determine version bump
|
||||
id: version_bump
|
||||
shell: bash
|
||||
run: |
|
||||
CURRENT_VERSION="${{ steps.current_version.outputs.current_version }}"
|
||||
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||
@@ -467,7 +630,7 @@ jobs:
|
||||
HAS_FIX=false
|
||||
HAS_BREAKING=false
|
||||
|
||||
# Analyze each commit
|
||||
# Analyze each commit - USE PROCESS SUBSTITUTION to avoid subshell variable scope issues
|
||||
while IFS= read -r commit; do
|
||||
if [[ "$commit" =~ ^FEATURE ]]; then
|
||||
HAS_FEATURE=true
|
||||
@@ -485,7 +648,7 @@ jobs:
|
||||
HAS_BREAKING=true
|
||||
echo "Found BREAKING CHANGE: $commit"
|
||||
fi
|
||||
done <<< "$COMMITS"
|
||||
done < <(printf '%s\n' "$COMMITS")
|
||||
|
||||
# Determine version bump
|
||||
if [ "$HAS_BREAKING" = true ]; then
|
||||
@@ -510,45 +673,18 @@ jobs:
|
||||
echo "No version bump needed"
|
||||
fi
|
||||
|
||||
build-only:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [test-backend, test-frontend]
|
||||
if: ${{ github.ref == 'refs/heads/main' && contains(github.event.head_commit.message, '[skip-release]') }}
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up QEMU (enables multi-arch emulation)
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Build and push SHA-only tags
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
platforms: linux/amd64,linux/arm64
|
||||
build-args: |
|
||||
APP_VERSION=dev-${{ github.sha }}
|
||||
tags: |
|
||||
databasus/databasus:latest
|
||||
databasus/databasus:${{ github.sha }}
|
||||
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
needs: [determine-version]
|
||||
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -578,21 +714,33 @@ jobs:
|
||||
databasus/databasus:${{ github.sha }}
|
||||
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: node:20
|
||||
needs: [determine-version, build-and-push]
|
||||
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Generate changelog
|
||||
id: changelog
|
||||
shell: bash
|
||||
run: |
|
||||
NEW_VERSION="${{ needs.determine-version.outputs.new_version }}"
|
||||
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||
@@ -612,6 +760,7 @@ jobs:
|
||||
FIXES=""
|
||||
REFACTORS=""
|
||||
|
||||
# USE PROCESS SUBSTITUTION to avoid subshell variable scope issues
|
||||
while IFS= read -r line; do
|
||||
if [ -n "$line" ]; then
|
||||
COMMIT_MSG=$(echo "$line" | cut -d'|' -f1)
|
||||
@@ -645,7 +794,7 @@ jobs:
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
done <<< "$COMMITS"
|
||||
done < <(printf '%s\n' "$COMMITS")
|
||||
|
||||
# Build changelog sections
|
||||
if [ -n "$FEATURES" ]; then
|
||||
@@ -684,16 +833,33 @@ jobs:
|
||||
prerelease: false
|
||||
|
||||
publish-helm-chart:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: alpine:3.19
|
||||
volumes:
|
||||
- /runner-cache/apk-cache:/etc/apk/cache
|
||||
needs: [determine-version, build-and-push]
|
||||
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apk add --no-cache git bash curl
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4
|
||||
with:
|
||||
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -1,12 +1,16 @@
|
||||
ansible/
|
||||
postgresus_data/
|
||||
postgresus-data/
|
||||
databasus-data/
|
||||
.env
|
||||
pgdata/
|
||||
docker-compose.yml
|
||||
!agent/e2e/docker-compose.yml
|
||||
node_modules/
|
||||
.idea
|
||||
/articles
|
||||
|
||||
.DS_Store
|
||||
/scripts
|
||||
/scripts
|
||||
.vscode/settings.json
|
||||
.claude
|
||||
@@ -18,6 +18,13 @@ repos:
|
||||
files: ^frontend/.*\.(ts|tsx|js|jsx)$
|
||||
pass_filenames: false
|
||||
|
||||
- id: frontend-build
|
||||
name: Frontend Build
|
||||
entry: bash -c "cd frontend && npm run build"
|
||||
language: system
|
||||
files: ^frontend/.*\.(ts|tsx|js|jsx|json|css)$
|
||||
pass_filenames: false
|
||||
|
||||
# Backend checks
|
||||
- repo: local
|
||||
hooks:
|
||||
@@ -27,3 +34,27 @@ repos:
|
||||
language: system
|
||||
files: ^backend/.*\.go$
|
||||
pass_filenames: false
|
||||
|
||||
- id: backend-go-mod-tidy
|
||||
name: Backend Go Mod Tidy
|
||||
entry: bash -c "cd backend && go mod tidy"
|
||||
language: system
|
||||
files: ^backend/.*\.go$
|
||||
pass_filenames: false
|
||||
|
||||
# Agent checks
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: agent-format-and-lint
|
||||
name: Agent Format & Lint (golangci-lint)
|
||||
entry: bash -c "cd agent && golangci-lint fmt ./internal/... ./cmd/... && golangci-lint run ./internal/... ./cmd/..."
|
||||
language: system
|
||||
files: ^agent/.*\.go$
|
||||
pass_filenames: false
|
||||
|
||||
- id: agent-go-mod-tidy
|
||||
name: Agent Go Mod Tidy
|
||||
entry: bash -c "cd agent && go mod tidy"
|
||||
language: system
|
||||
files: ^agent/.*\.go$
|
||||
pass_filenames: false
|
||||
|
||||
195
Dockerfile
195
Dockerfile
@@ -22,7 +22,7 @@ RUN npm run build
|
||||
|
||||
# ========= BUILD BACKEND =========
|
||||
# Backend build stage
|
||||
FROM --platform=$BUILDPLATFORM golang:1.24.4 AS backend-build
|
||||
FROM --platform=$BUILDPLATFORM golang:1.26.1 AS backend-build
|
||||
|
||||
# Make TARGET args available early so tools built here match the final image arch
|
||||
ARG TARGETOS
|
||||
@@ -66,13 +66,52 @@ RUN CGO_ENABLED=0 \
|
||||
go build -o /app/main ./cmd/main.go
|
||||
|
||||
|
||||
# ========= BUILD AGENT =========
|
||||
# Builds the databasus-agent CLI binary for BOTH x86_64 and ARM64.
|
||||
# Both architectures are always built because:
|
||||
# - Databasus server runs on one arch (e.g. amd64)
|
||||
# - The agent runs on remote PostgreSQL servers that may be on a
|
||||
# different arch (e.g. arm64)
|
||||
# - The backend serves the correct binary based on the agent's
|
||||
# ?arch= query parameter
|
||||
#
|
||||
# We cross-compile from the build platform (no QEMU needed) because the
|
||||
# agent is pure Go with zero C dependencies.
|
||||
# CGO_ENABLED=0 produces fully static binaries — no glibc/musl dependency,
|
||||
# so the agent runs on any Linux distro (Alpine, Debian, Ubuntu, RHEL, etc.).
|
||||
# APP_VERSION is baked into the binary via -ldflags so the agent can
|
||||
# compare its version against the server and auto-update when needed.
|
||||
FROM --platform=$BUILDPLATFORM golang:1.26.1 AS agent-build
|
||||
|
||||
ARG APP_VERSION=dev
|
||||
|
||||
WORKDIR /agent
|
||||
|
||||
COPY agent/go.mod ./
|
||||
RUN go mod download
|
||||
|
||||
COPY agent/ ./
|
||||
|
||||
# Build for x86_64 (amd64) — static binary, no glibc dependency
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 \
|
||||
go build -ldflags "-X main.Version=${APP_VERSION}" \
|
||||
-o /agent-binaries/databasus-agent-linux-amd64 ./cmd/main.go
|
||||
|
||||
# Build for ARM64 (arm64) — static binary, no glibc dependency
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=arm64 \
|
||||
go build -ldflags "-X main.Version=${APP_VERSION}" \
|
||||
-o /agent-binaries/databasus-agent-linux-arm64 ./cmd/main.go
|
||||
|
||||
|
||||
# ========= RUNTIME =========
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
# Add version metadata to runtime image
|
||||
ARG APP_VERSION=dev
|
||||
ARG TARGETARCH
|
||||
LABEL org.opencontainers.image.version=$APP_VERSION
|
||||
ENV APP_VERSION=$APP_VERSION
|
||||
ENV CONTAINER_ARCH=$TARGETARCH
|
||||
|
||||
# Set production mode for Docker containers
|
||||
ENV ENV_MODE=production
|
||||
@@ -123,6 +162,15 @@ RUN wget -qO- https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add -
|
||||
apt-get install -y --no-install-recommends postgresql-17 && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Valkey server from debian repository
|
||||
# Valkey is only accessible internally (localhost) - not exposed outside container
|
||||
RUN wget -O /usr/share/keyrings/greensec.github.io-valkey-debian.key https://greensec.github.io/valkey-debian/public.key && \
|
||||
echo "deb [signed-by=/usr/share/keyrings/greensec.github.io-valkey-debian.key] https://greensec.github.io/valkey-debian/repo $(lsb_release -cs) main" \
|
||||
> /etc/apt/sources.list.d/valkey-debian.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends valkey && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# ========= Install rclone =========
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends rclone && \
|
||||
@@ -191,7 +239,8 @@ RUN apt-get update && \
|
||||
fi
|
||||
|
||||
# Create postgres user and set up directories
|
||||
RUN useradd -m -s /bin/bash postgres || true && \
|
||||
RUN groupadd -g 999 postgres || true && \
|
||||
useradd -m -s /bin/bash -u 999 -g 999 postgres || true && \
|
||||
mkdir -p /databasus-data/pgdata && \
|
||||
chown -R postgres:postgres /databasus-data/pgdata
|
||||
|
||||
@@ -209,6 +258,13 @@ COPY backend/migrations ./migrations
|
||||
# Copy UI files
|
||||
COPY --from=backend-build /app/ui/build ./ui/build
|
||||
|
||||
# Copy cloud static HTML template (injected into index.html at startup when IS_CLOUD=true)
|
||||
COPY frontend/cloud-root-content.html /app/cloud-root-content.html
|
||||
|
||||
# Copy agent binaries (both architectures) — served by the backend
|
||||
# at GET /api/v1/system/agent?arch=amd64|arm64
|
||||
COPY --from=agent-build /agent-binaries ./agent-binaries
|
||||
|
||||
# Copy .env file (with fallback to .env.production.example)
|
||||
COPY backend/.env* /app/
|
||||
RUN if [ ! -f /app/.env ]; then \
|
||||
@@ -239,9 +295,88 @@ if [ -d "/postgresus-data" ] && [ "\$(ls -A /postgresus-data 2>/dev/null)" ]; th
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# ========= Adjust postgres user UID/GID =========
|
||||
PUID=\${PUID:-999}
|
||||
PGID=\${PGID:-999}
|
||||
|
||||
CURRENT_UID=\$(id -u postgres)
|
||||
CURRENT_GID=\$(id -g postgres)
|
||||
|
||||
if [ "\$CURRENT_GID" != "\$PGID" ]; then
|
||||
echo "Adjusting postgres group GID from \$CURRENT_GID to \$PGID..."
|
||||
groupmod -o -g "\$PGID" postgres
|
||||
fi
|
||||
|
||||
if [ "\$CURRENT_UID" != "\$PUID" ]; then
|
||||
echo "Adjusting postgres user UID from \$CURRENT_UID to \$PUID..."
|
||||
usermod -o -u "\$PUID" postgres
|
||||
fi
|
||||
|
||||
chown -R postgres:postgres /var/run/postgresql
|
||||
|
||||
# PostgreSQL 17 binary paths
|
||||
PG_BIN="/usr/lib/postgresql/17/bin"
|
||||
|
||||
# Generate runtime configuration for frontend
|
||||
echo "Generating runtime configuration..."
|
||||
|
||||
# Detect if email is configured (both SMTP_HOST and DATABASUS_URL must be set)
|
||||
if [ -n "\${SMTP_HOST:-}" ] && [ -n "\${DATABASUS_URL:-}" ]; then
|
||||
IS_EMAIL_CONFIGURED="true"
|
||||
else
|
||||
IS_EMAIL_CONFIGURED="false"
|
||||
fi
|
||||
|
||||
cat > /app/ui/build/runtime-config.js <<JSEOF
|
||||
// Runtime configuration injected at container startup
|
||||
// This file is generated dynamically and should not be edited manually
|
||||
window.__RUNTIME_CONFIG__ = {
|
||||
IS_CLOUD: '\${IS_CLOUD:-false}',
|
||||
GITHUB_CLIENT_ID: '\${GITHUB_CLIENT_ID:-}',
|
||||
GOOGLE_CLIENT_ID: '\${GOOGLE_CLIENT_ID:-}',
|
||||
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED',
|
||||
CLOUDFLARE_TURNSTILE_SITE_KEY: '\${CLOUDFLARE_TURNSTILE_SITE_KEY:-}',
|
||||
CONTAINER_ARCH: '\${CONTAINER_ARCH:-unknown}',
|
||||
CLOUD_PRICE_PER_GB: '\${CLOUD_PRICE_PER_GB:-}',
|
||||
CLOUD_PADDLE_CLIENT_TOKEN: '\${CLOUD_PADDLE_CLIENT_TOKEN:-}'
|
||||
};
|
||||
JSEOF
|
||||
|
||||
# Inject analytics script if provided (only if not already injected)
|
||||
if [ -n "\${ANALYTICS_SCRIPT:-}" ]; then
|
||||
if ! grep -q "rybbit.databasus.com" /app/ui/build/index.html 2>/dev/null; then
|
||||
echo "Injecting analytics script..."
|
||||
sed -i "s#</head># \${ANALYTICS_SCRIPT}\\
|
||||
</head>#" /app/ui/build/index.html
|
||||
fi
|
||||
fi
|
||||
|
||||
# Inject Paddle script if client token is provided (only if not already injected)
|
||||
if [ -n "\${CLOUD_PADDLE_CLIENT_TOKEN:-}" ]; then
|
||||
if ! grep -q "cdn.paddle.com" /app/ui/build/index.html 2>/dev/null; then
|
||||
echo "Injecting Paddle script..."
|
||||
sed -i "s#</head># <script src=\"https://cdn.paddle.com/paddle/v2/paddle.js\"></script>\\
|
||||
</head>#" /app/ui/build/index.html
|
||||
fi
|
||||
fi
|
||||
|
||||
# Inject static HTML into root div for cloud mode (payment system requires visible legal links)
|
||||
if [ "\${IS_CLOUD:-false}" = "true" ]; then
|
||||
if ! grep -q "cloud-static-content" /app/ui/build/index.html 2>/dev/null; then
|
||||
echo "Injecting cloud static HTML content..."
|
||||
perl -i -pe '
|
||||
BEGIN {
|
||||
open my \$fh, "<", "/app/cloud-root-content.html" or die;
|
||||
local \$/;
|
||||
\$c = <\$fh>;
|
||||
close \$fh;
|
||||
\$c =~ s/\\n/ /g;
|
||||
}
|
||||
s/<div id="root"><\\/div>/<div id="root"><!-- cloud-static-content --><noscript>\$c<\\/noscript><\\/div>/
|
||||
' /app/ui/build/index.html
|
||||
fi
|
||||
fi
|
||||
|
||||
# Ensure proper ownership of data directory
|
||||
echo "Setting up data directory permissions..."
|
||||
mkdir -p /databasus-data/pgdata
|
||||
@@ -250,6 +385,30 @@ mkdir -p /databasus-data/backups
|
||||
chown -R postgres:postgres /databasus-data
|
||||
chmod 700 /databasus-data/temp
|
||||
|
||||
# ========= Start Valkey (internal cache) =========
|
||||
echo "Configuring Valkey cache..."
|
||||
cat > /tmp/valkey.conf << 'VALKEY_CONFIG'
|
||||
port 6379
|
||||
bind 127.0.0.1
|
||||
protected-mode yes
|
||||
save ""
|
||||
maxmemory 256mb
|
||||
maxmemory-policy allkeys-lru
|
||||
VALKEY_CONFIG
|
||||
|
||||
echo "Starting Valkey..."
|
||||
valkey-server /tmp/valkey.conf &
|
||||
VALKEY_PID=\$!
|
||||
|
||||
echo "Waiting for Valkey to be ready..."
|
||||
for i in {1..30}; do
|
||||
if valkey-cli ping >/dev/null 2>&1; then
|
||||
echo "Valkey is ready!"
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
# Initialize PostgreSQL if not already initialized
|
||||
if [ ! -s "/databasus-data/pgdata/PG_VERSION" ]; then
|
||||
echo "Initializing PostgreSQL database..."
|
||||
@@ -330,6 +489,8 @@ fi
|
||||
# Create database and set password for postgres user
|
||||
echo "Setting up database and user..."
|
||||
gosu postgres \$PG_BIN/psql -p 5437 -h localhost -d postgres << 'SQL'
|
||||
|
||||
-- We use stub password, because internal DB is not exposed outside container
|
||||
ALTER USER postgres WITH PASSWORD 'Q1234567';
|
||||
SELECT 'CREATE DATABASE databasus OWNER postgres'
|
||||
WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = 'databasus')
|
||||
@@ -339,9 +500,37 @@ SQL
|
||||
|
||||
# Start the main application
|
||||
echo "Starting Databasus application..."
|
||||
|
||||
# Check and warn about external database/Valkey usage
|
||||
if [ -n "\${DANGEROUS_EXTERNAL_DATABASE_DSN:-}" ]; then
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "WARNING: Using external database"
|
||||
echo "=========================================="
|
||||
echo "DANGEROUS_EXTERNAL_DATABASE_DSN is set."
|
||||
echo "Application will connect to external PostgreSQL instead of internal instance."
|
||||
echo "Internal PostgreSQL is still running in the background."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
fi
|
||||
|
||||
if [ -n "\${DANGEROUS_VALKEY_HOST:-}" ]; then
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "WARNING: Using external Valkey"
|
||||
echo "=========================================="
|
||||
echo "DANGEROUS_VALKEY_HOST is set."
|
||||
echo "Application will connect to external Valkey instead of internal instance."
|
||||
echo "Internal Valkey is still running in the background."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
fi
|
||||
|
||||
exec ./main
|
||||
EOF
|
||||
|
||||
LABEL org.opencontainers.image.source="https://github.com/databasus/databasus"
|
||||
|
||||
RUN chmod +x /app/start.sh
|
||||
|
||||
EXPOSE 4005
|
||||
@@ -350,4 +539,4 @@ EXPOSE 4005
|
||||
VOLUME ["/databasus-data"]
|
||||
|
||||
ENTRYPOINT ["/app/start.sh"]
|
||||
CMD []
|
||||
CMD []
|
||||
|
||||
22
NOTICE.md
Normal file
22
NOTICE.md
Normal file
@@ -0,0 +1,22 @@
|
||||
Copyright © 2025–2026 Rostislav Dugin and contributors.
|
||||
|
||||
“Databasus” is a trademark of Rostislav Dugin.
|
||||
|
||||
The source code in this repository is licensed under the Apache License, Version 2.0.
|
||||
That license applies to the code only and does not grant any right to use the
|
||||
Databasus name, logo, or branding, except for reasonable and customary referential
|
||||
use in describing the origin of the software and reproducing the content of this NOTICE.
|
||||
|
||||
Permitted referential use includes truthful use of the name “Databasus” to identify
|
||||
the original Databasus project in software catalogs, deployment templates, hosting
|
||||
panels, package indexes, compatibility pages, integrations, tutorials, reviews, and
|
||||
similar informational materials, including phrases such as “Databasus”,
|
||||
“Deploy Databasus”, “Databasus on Coolify”, and “Compatible with Databasus”.
|
||||
|
||||
You may not use “Databasus” as the name or primary branding of a competing product,
|
||||
service, fork, distribution, or hosted offering, or in any manner likely to cause
|
||||
confusion as to source, affiliation, sponsorship, or endorsement.
|
||||
|
||||
Nothing in this repository transfers, waives, limits, or estops any rights in the
|
||||
Databasus mark. All trademark rights are reserved except for the limited referential
|
||||
use stated above.
|
||||
97
README.md
97
README.md
@@ -1,8 +1,8 @@
|
||||
<div align="center">
|
||||
<img src="assets/logo.svg" alt="Databasus Logo" width="250"/>
|
||||
|
||||
<h3>Backup tool for PostgreSQL, MySQL and MongoDB</h3>
|
||||
<p>Databasus is a free, open source and self-hosted tool to backup databases (with focus on PostgreSQL). Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.). Previously known as Postgresus (see migration guide).</p>
|
||||
<h3>PostgreSQL backup tool (with MySQL\MariaDB and MongoDB support)</h3>
|
||||
<p>Databasus is a free, open source and self-hosted tool to backup databases (with primary focus on PostgreSQL). Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.)</p>
|
||||
|
||||
<!-- Badges -->
|
||||
[](https://www.postgresql.org/)
|
||||
@@ -11,7 +11,7 @@
|
||||
[](https://www.mongodb.com/)
|
||||
<br />
|
||||
[](LICENSE)
|
||||
[](https://hub.docker.com/r/rostislavdugin/postgresus)
|
||||
[](https://hub.docker.com/r/databasus/databasus)
|
||||
[](https://github.com/databasus/databasus)
|
||||
[](https://github.com/databasus/databasus)
|
||||
[](https://github.com/databasus/databasus)
|
||||
@@ -31,8 +31,6 @@
|
||||
<img src="assets/dashboard-dark.svg" alt="Databasus Dark Dashboard" width="800" style="margin-bottom: 10px;"/>
|
||||
|
||||
<img src="assets/dashboard.svg" alt="Databasus Dashboard" width="800"/>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
@@ -43,7 +41,7 @@
|
||||
|
||||
- **PostgreSQL**: 12, 13, 14, 15, 16, 17 and 18
|
||||
- **MySQL**: 5.7, 8 and 9
|
||||
- **MariaDB**: 10 and 11
|
||||
- **MariaDB**: 10, 11 and 12
|
||||
- **MongoDB**: 4, 5, 6, 7 and 8
|
||||
|
||||
### 🔄 **Scheduled backups**
|
||||
@@ -52,6 +50,13 @@
|
||||
- **Precise timing**: run backups at specific times (e.g., 4 AM during low traffic)
|
||||
- **Smart compression**: 4-8x space savings with balanced compression (~20% overhead)
|
||||
|
||||
### 🗑️ **Retention policies**
|
||||
|
||||
- **Time period**: Keep backups for a fixed duration (e.g., 7 days, 3 months, 1 year)
|
||||
- **Count**: Keep a fixed number of the most recent backups (e.g., last 30)
|
||||
- **GFS (Grandfather-Father-Son)**: Layered retention — keep hourly, daily, weekly, monthly and yearly backups independently for fine-grained long-term history (enterprises requirement)
|
||||
- **Size limits**: Set per-backup and total storage size caps to control storage usage
|
||||
|
||||
### 🗄️ **Multiple storage destinations** <a href="https://databasus.com/storages">(view supported)</a>
|
||||
|
||||
- **Local storage**: Keep backups on your VPS/server
|
||||
@@ -71,6 +76,8 @@
|
||||
- **Encryption for secrets**: Any sensitive data is encrypted and never exposed, even in logs or error messages
|
||||
- **Read-only user**: Databasus uses a read-only user by default for backups and never stores anything that can modify your data
|
||||
|
||||
It is also important for Databasus that you are able to decrypt and restore backups from storages (local, S3, etc.) without Databasus itself. To do so, read our guide on [how to recover directly from storage](https://databasus.com/how-to-recover-without-databasus). We avoid "vendor lock-in" even to open source tool!
|
||||
|
||||
### 👥 **Suitable for teams** <a href="https://databasus.com/access-management">(docs)</a>
|
||||
|
||||
- **Workspaces**: Group databases, notifiers and storages for different projects or teams
|
||||
@@ -84,14 +91,16 @@
|
||||
- **Dark & light themes**: Choose the look that suits your workflow
|
||||
- **Mobile adaptive**: Check your backups from anywhere on any device
|
||||
|
||||
### ☁️ **Works with self-hosted & cloud databases**
|
||||
### 🔌 **Connection types**
|
||||
|
||||
Databasus works seamlessly with both self-hosted PostgreSQL and cloud-managed databases:
|
||||
- **Remote** — Databasus connects directly to the database over the network (recommended in read-only mode). No agent or additional software required. Works with cloud-managed and self-hosted databases
|
||||
- **Agent** — A lightweight Databasus agent (written in Go) runs alongside the database. The agent streams backups directly to Databasus, so the database never needs to be exposed publicly. Supports host-installed databases and Docker containers
|
||||
|
||||
- **Cloud support**: AWS RDS, Google Cloud SQL, Azure Database for PostgreSQL
|
||||
- **Self-hosted**: Any PostgreSQL instance you manage yourself
|
||||
- **Why no PITR support?**: Cloud providers already offer native PITR, and external PITR backups cannot be restored to managed cloud databases — making them impractical for cloud-hosted PostgreSQL
|
||||
- **Practical granularity**: Hourly and daily backups are sufficient for 99% of projects without the operational complexity of WAL archiving
|
||||
### 📦 **Backup types**
|
||||
|
||||
- **Logical** — Native dump of the database in its engine-specific binary format. Compressed and streamed directly to storage with no intermediate files
|
||||
- **Physical** — File-level copy of the entire database cluster. Faster backup and restore for large datasets compared to logical dumps
|
||||
- **Incremental** — Physical base backup combined with continuous WAL segment archiving. Enables Point-in-time recovery (PITR) — restore to any second between backups. Designed for disaster recovery and near-zero data loss requirements
|
||||
|
||||
### 🐳 **Self-hosted & secure**
|
||||
|
||||
@@ -220,8 +229,9 @@ For more options (NodePort, TLS, HTTPRoute for Gateway API), see the [Helm chart
|
||||
3. **Configure schedule**: Choose from hourly, daily, weekly, monthly or cron intervals
|
||||
4. **Set database connection**: Enter your database credentials and connection details
|
||||
5. **Choose storage**: Select where to store your backups (local, S3, Google Drive, etc.)
|
||||
6. **Add notifications** (optional): Configure email, Telegram, Slack, or webhook notifications
|
||||
7. **Save and start**: Databasus will validate settings and begin the backup schedule
|
||||
6. **Configure retention policy**: Choose time period, count or GFS to control how long backups are kept
|
||||
7. **Add notifications** (optional): Configure email, Telegram, Slack, or webhook notifications
|
||||
8. **Save and start**: Databasus will validate settings and begin the backup schedule
|
||||
|
||||
### 🔑 Resetting password <a href="https://databasus.com/password">(docs)</a>
|
||||
|
||||
@@ -233,66 +243,39 @@ docker exec -it databasus ./main --new-password="YourNewSecurePassword123" --ema
|
||||
|
||||
Replace `admin` with the actual email address of the user whose password you want to reset.
|
||||
|
||||
### 💾 Backuping Databasus itself
|
||||
|
||||
After installation, it is also recommended to <a href="https://databasus.com/faq#backup-databasus">backup your Databasus itself</a> or, at least, to copy secret key used for encryption (30 seconds is needed). So you are able to restore from your encrypted backups if you lose access to the server with Databasus or it is corrupted.
|
||||
|
||||
---
|
||||
|
||||
## 📝 License
|
||||
|
||||
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details
|
||||
|
||||
---
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
Contributions are welcome! Read the <a href="https://databasus.com/contribute">contributing guide</a> for more details, priorities and rules. If you want to contribute but don't know where to start, message me on Telegram [@rostislav_dugin](https://t.me/rostislav_dugin)
|
||||
|
||||
Also you can join our large community of developers, DBAs and DevOps engineers on Telegram [@databasus_community](https://t.me/databasus_community).
|
||||
|
||||
--
|
||||
## FAQ
|
||||
|
||||
## 📖 Migration guide
|
||||
|
||||
Databasus is the new name for Postgresus. You can stay with latest version of Postgresus if you wish. If you want to migrate - follow installation steps for Databasus itself.
|
||||
|
||||
Just renaming an image is not enough as Postgresus and Databasus use different data folders and internal database naming.
|
||||
|
||||
You can put a new Databasus image with updated volume near the old Postgresus and run it (stop Postgresus before):
|
||||
|
||||
```
|
||||
services:
|
||||
databasus:
|
||||
container_name: databasus
|
||||
image: databasus/databasus:latest
|
||||
ports:
|
||||
- "4005:4005"
|
||||
volumes:
|
||||
- ./databasus-data:/databasus-data
|
||||
restart: unless-stopped
|
||||
```
|
||||
|
||||
Then manually move databases from Postgresus to Databasus.
|
||||
|
||||
### Why was Postgresus renamed to Databasus?
|
||||
|
||||
Databasus has been developed since 2023. It was internal tool to backup production and home projects databases. In start of 2025 it was released as open source project on GitHub. By the end of 2025 it became popular and the time for renaming has come in December 2025.
|
||||
|
||||
It was an important step for the project to grow. Actually, there are a couple of reasons:
|
||||
|
||||
1. Postgresus is no longer a little tool that just adds UI for pg_dump for little projects. It became a tool both for individual users, DevOps, DBAs, teams, companies and even large enterprises. Tens of thousands of users use Postgresus every day. Postgresus grew into a reliable backup management tool. Initial positioning is no longer suitable: the project is not just a UI wrapper, it's a solid backup management system now (despite it's still easy to use).
|
||||
|
||||
2. New databases are supported: although the primary focus is PostgreSQL (with 100% support in the most efficient way) and always will be, Databasus added support for MySQL, MariaDB and MongoDB. Later more databases will be supported.
|
||||
|
||||
3. Trademark issue: "postgres" is a trademark of PostgreSQL Inc. and cannot be used in the project name. So for safety and legal reasons, we had to rename the project.
|
||||
|
||||
## AI disclaimer
|
||||
### AI disclaimer
|
||||
|
||||
There have been questions about AI usage in project development in issues and discussions. As the project focuses on security, reliability and production usage, it's important to explain how AI is used in the development process.
|
||||
|
||||
First of all, we are proud to say that Databasus has been accepted into both [Claude for Open Source](https://claude.com/contact-sales/claude-for-oss) by Anthropic and [Codex for Open Source](https://developers.openai.com/codex/community/codex-for-oss/) by OpenAI in March 2026. For us it is one more signal that the project was recognized as important open-source software and was as critical infrastructure worth supporting independently by two of the world's leading AI companies. Read more at [databasus.com/faq](https://databasus.com/faq#oss-programs).
|
||||
|
||||
Despite of this, we have the following rules how AI is used in the development process:
|
||||
|
||||
AI is used as a helper for:
|
||||
|
||||
- verification of code quality and searching for vulnerabilities
|
||||
- cleaning up and improving documentation, comments and code
|
||||
- assistance during development
|
||||
- double-checking PRs and commits after human review
|
||||
- additional security analysis of PRs via Codex Security
|
||||
|
||||
AI is not used for:
|
||||
|
||||
@@ -314,3 +297,13 @@ Moreover, it's important to note that we do not differentiate between bad human
|
||||
Even if code is written manually by a human, it's not guaranteed to be merged. Vibe code is not allowed at all and all such PRs are rejected by default (see [contributing guide](https://databasus.com/contribute)).
|
||||
|
||||
We also draw attention to fast issue resolution and security [vulnerability reporting](https://github.com/databasus/databasus?tab=security-ov-file#readme).
|
||||
|
||||
### You have a cloud version — are you truly open source?
|
||||
|
||||
Yes. Every feature available in Databasus Cloud is equally available in the self-hosted version with no restrictions, no feature gates and no usage limits. The entire codebase is Apache 2.0 licensed and always will be.
|
||||
|
||||
Databasus is not "open core." We do not withhold features behind a paid tier and then call the limited remainder "open source," as projects like GitLab or Sentry do. We believe open source means the complete product is open, not just a marketing label on a stripped-down edition.
|
||||
|
||||
Databasus Cloud runs the exact same code as the self-hosted version. The only difference is that we take care of infrastructure, availability, backups, reservations, monitoring and updates for you — so you don't have to. If you are using cloud, you can always move your databases from cloud to self-hosted if you wish.
|
||||
|
||||
Revenue from Cloud funds full-time development of the project. Most large open-source projects rely on corporate backing or sponsorship to survive. We chose a different path: Databasus sustains itself so it can grow and improve independently, without being tied to any enterprise or sponsor.
|
||||
|
||||
3
agent/.env.example
Normal file
3
agent/.env.example
Normal file
@@ -0,0 +1,3 @@
|
||||
ENV_MODE=development
|
||||
AGENT_DB_ID=your-database-id
|
||||
AGENT_TOKEN=your-agent-token
|
||||
27
agent/.gitignore
vendored
Normal file
27
agent/.gitignore
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
main
|
||||
.env
|
||||
docker-compose.yml
|
||||
!e2e/docker-compose.yml
|
||||
pgdata
|
||||
pgdata_test/
|
||||
mysqldata/
|
||||
mariadbdata/
|
||||
main.exe
|
||||
swagger/
|
||||
swagger/*
|
||||
swagger/docs.go
|
||||
swagger/swagger.json
|
||||
swagger/swagger.yaml
|
||||
postgresus-backend.exe
|
||||
databasus-backend.exe
|
||||
ui/build/*
|
||||
pgdata-for-restore/
|
||||
temp/
|
||||
cmd.exe
|
||||
temp/
|
||||
valkey-data/
|
||||
victoria-logs-data/
|
||||
databasus.json
|
||||
.test-tmp/
|
||||
databasus.log
|
||||
wal-queue/
|
||||
41
agent/.golangci.yml
Normal file
41
agent/.golangci.yml
Normal file
@@ -0,0 +1,41 @@
|
||||
version: "2"
|
||||
|
||||
run:
|
||||
timeout: 5m
|
||||
tests: false
|
||||
concurrency: 4
|
||||
|
||||
linters:
|
||||
default: standard
|
||||
enable:
|
||||
- funcorder
|
||||
- bodyclose
|
||||
- errorlint
|
||||
- gocritic
|
||||
- unconvert
|
||||
- misspell
|
||||
- errname
|
||||
- noctx
|
||||
- modernize
|
||||
|
||||
settings:
|
||||
errcheck:
|
||||
check-type-assertions: true
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
- gofumpt
|
||||
- golines
|
||||
- gci
|
||||
|
||||
settings:
|
||||
golines:
|
||||
max-len: 120
|
||||
gofumpt:
|
||||
module-path: databasus-agent
|
||||
extra-rules: true
|
||||
gci:
|
||||
sections:
|
||||
- standard
|
||||
- default
|
||||
- localmodule
|
||||
41
agent/Makefile
Normal file
41
agent/Makefile
Normal file
@@ -0,0 +1,41 @@
|
||||
.PHONY: run build test lint e2e e2e-clean e2e-backup-restore e2e-backup-restore-clean
|
||||
|
||||
-include .env
|
||||
export
|
||||
|
||||
run:
|
||||
go run cmd/main.go start \
|
||||
--databasus-host http://localhost:4005 \
|
||||
--db-id $(AGENT_DB_ID) \
|
||||
--token $(AGENT_TOKEN) \
|
||||
--pg-host 127.0.0.1 \
|
||||
--pg-port 7433 \
|
||||
--pg-user devuser \
|
||||
--pg-password devpassword \
|
||||
--pg-type docker \
|
||||
--pg-docker-container-name dev-postgres \
|
||||
--pg-wal-dir ./wal-queue \
|
||||
--skip-update
|
||||
|
||||
build:
|
||||
CGO_ENABLED=0 go build -ldflags "-X main.Version=$(VERSION)" -o databasus-agent ./cmd/main.go
|
||||
|
||||
test:
|
||||
go test -count=1 -failfast ./internal/...
|
||||
|
||||
lint:
|
||||
golangci-lint fmt ./cmd/... ./internal/... ./e2e/... && golangci-lint run ./cmd/... ./internal/... ./e2e/...
|
||||
|
||||
e2e:
|
||||
cd e2e && docker compose build --no-cache e2e-mock-server
|
||||
cd e2e && docker compose build
|
||||
cd e2e && docker compose run --rm e2e-agent-builder
|
||||
cd e2e && docker compose up -d e2e-postgres e2e-mock-server
|
||||
cd e2e && docker compose run --rm e2e-agent-runner
|
||||
cd e2e && docker compose run --rm e2e-agent-docker
|
||||
cd e2e && docker compose down -v
|
||||
|
||||
e2e-clean:
|
||||
cd e2e && docker compose down -v --rmi local
|
||||
cd e2e && docker compose -f docker-compose.backup-restore.yml down -v --rmi local 2>/dev/null || true
|
||||
rm -rf e2e/artifacts
|
||||
245
agent/cmd/main.go
Normal file
245
agent/cmd/main.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
"databasus-agent/internal/features/restore"
|
||||
"databasus-agent/internal/features/start"
|
||||
"databasus-agent/internal/features/upgrade"
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
var Version = "dev"
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 2 {
|
||||
printUsage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
switch os.Args[1] {
|
||||
case "start":
|
||||
runStart(os.Args[2:])
|
||||
case "_run":
|
||||
runDaemon(os.Args[2:])
|
||||
case "stop":
|
||||
runStop()
|
||||
case "status":
|
||||
runStatus()
|
||||
case "restore":
|
||||
runRestore(os.Args[2:])
|
||||
case "version":
|
||||
fmt.Println(Version)
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "unknown command: %s\n", os.Args[1])
|
||||
printUsage()
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runStart(args []string) {
|
||||
fs := flag.NewFlagSet("start", flag.ExitOnError)
|
||||
|
||||
isSkipUpdate := fs.Bool("skip-update", false, "Skip auto-update check")
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.LoadFromJSONAndArgs(fs, args)
|
||||
|
||||
if err := cfg.SaveToJSON(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to save config: %v\n", err)
|
||||
}
|
||||
|
||||
log := logger.GetLogger()
|
||||
|
||||
isDev := checkIsDevelopment()
|
||||
runUpdateCheck(cfg.DatabasusHost, *isSkipUpdate, isDev, log)
|
||||
|
||||
if err := start.Start(cfg, Version, isDev, log); err != nil {
|
||||
if errors.Is(err, upgrade.ErrUpgradeRestart) {
|
||||
reexecAfterUpgrade(log)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runDaemon(args []string) {
|
||||
fs := flag.NewFlagSet("_run", flag.ExitOnError)
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
log := logger.GetLogger()
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.LoadFromJSON()
|
||||
|
||||
if err := start.RunDaemon(cfg, Version, checkIsDevelopment(), log); err != nil {
|
||||
if errors.Is(err, upgrade.ErrUpgradeRestart) {
|
||||
reexecAfterUpgrade(log)
|
||||
}
|
||||
|
||||
log.Error("Agent exited with error", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runStop() {
|
||||
log := logger.GetLogger()
|
||||
|
||||
if err := start.Stop(log); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runStatus() {
|
||||
log := logger.GetLogger()
|
||||
|
||||
if err := start.Status(log); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runRestore(args []string) {
|
||||
fs := flag.NewFlagSet("restore", flag.ExitOnError)
|
||||
|
||||
pgDataDir := fs.String("target-dir", "", "Target pgdata directory (required)")
|
||||
backupID := fs.String("backup-id", "", "Full backup UUID (optional)")
|
||||
targetTime := fs.String("target-time", "", "PITR target time in RFC3339 (optional)")
|
||||
isSkipUpdate := fs.Bool("skip-update", false, "Skip auto-update check")
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.LoadFromJSONAndArgs(fs, args)
|
||||
|
||||
if err := cfg.SaveToJSON(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to save config: %v\n", err)
|
||||
}
|
||||
|
||||
log := logger.GetLogger()
|
||||
|
||||
isDev := checkIsDevelopment()
|
||||
runUpdateCheck(cfg.DatabasusHost, *isSkipUpdate, isDev, log)
|
||||
|
||||
if *pgDataDir == "" {
|
||||
fmt.Fprintln(os.Stderr, "Error: --target-dir is required")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if cfg.DatabasusHost == "" || cfg.Token == "" {
|
||||
fmt.Fprintln(os.Stderr, "Error: databasus-host and token must be configured")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if cfg.PgType != "host" && cfg.PgType != "docker" {
|
||||
fmt.Fprintf(os.Stderr, "Error: --pg-type must be 'host' or 'docker', got '%s'\n", cfg.PgType)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
apiClient := api.NewClient(cfg.DatabasusHost, cfg.Token, log)
|
||||
restorer := restore.NewRestorer(apiClient, log, *pgDataDir, *backupID, *targetTime, cfg.PgType)
|
||||
|
||||
ctx := context.Background()
|
||||
if err := restorer.Run(ctx); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func printUsage() {
|
||||
fmt.Fprintln(os.Stderr, "Usage: databasus-agent <command> [flags]")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "Commands:")
|
||||
fmt.Fprintln(os.Stderr, " start Start the agent (WAL archiving + basebackups)")
|
||||
fmt.Fprintln(os.Stderr, " stop Stop a running agent")
|
||||
fmt.Fprintln(os.Stderr, " status Show agent status")
|
||||
fmt.Fprintln(os.Stderr, " restore Restore a database from backup")
|
||||
fmt.Fprintln(os.Stderr, " version Print agent version")
|
||||
}
|
||||
|
||||
func runUpdateCheck(host string, isSkipUpdate, isDev bool, log *slog.Logger) {
|
||||
if isSkipUpdate {
|
||||
return
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
return
|
||||
}
|
||||
|
||||
apiClient := api.NewClient(host, "", log)
|
||||
|
||||
isUpgraded, err := upgrade.CheckAndUpdate(apiClient, Version, isDev, log)
|
||||
if err != nil {
|
||||
log.Error("Auto-update failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if isUpgraded {
|
||||
reexecAfterUpgrade(log)
|
||||
}
|
||||
}
|
||||
|
||||
func checkIsDevelopment() bool {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for range 3 {
|
||||
if data, err := os.ReadFile(filepath.Join(dir, ".env")); err == nil {
|
||||
return parseEnvMode(data)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
dir = filepath.Dir(dir)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func parseEnvMode(data []byte) bool {
|
||||
for line := range strings.SplitSeq(string(data), "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) == 2 && strings.TrimSpace(parts[0]) == "ENV_MODE" {
|
||||
return strings.TrimSpace(parts[1]) == "development"
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func reexecAfterUpgrade(log *slog.Logger) {
|
||||
selfPath, err := os.Executable()
|
||||
if err != nil {
|
||||
log.Error("Failed to resolve executable for re-exec", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
log.Info("Re-executing after upgrade...")
|
||||
|
||||
if err := syscall.Exec(selfPath, os.Args, os.Environ()); err != nil {
|
||||
log.Error("Failed to re-exec after upgrade", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
58
agent/docker-compose.yml.example
Normal file
58
agent/docker-compose.yml.example
Normal file
@@ -0,0 +1,58 @@
|
||||
services:
|
||||
dev-postgres:
|
||||
image: postgres:17
|
||||
container_name: dev-postgres
|
||||
environment:
|
||||
POSTGRES_DB: devdb
|
||||
POSTGRES_USER: devuser
|
||||
POSTGRES_PASSWORD: devpassword
|
||||
ports:
|
||||
- "7433:5432"
|
||||
command:
|
||||
- bash
|
||||
- -c
|
||||
- |
|
||||
mkdir -p /wal-queue && chown postgres:postgres /wal-queue
|
||||
exec docker-entrypoint.sh postgres \
|
||||
-c wal_level=replica \
|
||||
-c max_wal_senders=3 \
|
||||
-c archive_mode=on \
|
||||
-c "archive_command=cp %p /wal-queue/%f"
|
||||
volumes:
|
||||
- ./wal-queue:/wal-queue
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U devuser -d devdb"]
|
||||
interval: 2s
|
||||
timeout: 5s
|
||||
retries: 30
|
||||
|
||||
db-writer:
|
||||
image: postgres:17
|
||||
container_name: dev-db-writer
|
||||
depends_on:
|
||||
dev-postgres:
|
||||
condition: service_healthy
|
||||
environment:
|
||||
PGHOST: dev-postgres
|
||||
PGPORT: "5432"
|
||||
PGUSER: devuser
|
||||
PGPASSWORD: devpassword
|
||||
PGDATABASE: devdb
|
||||
command:
|
||||
- bash
|
||||
- -c
|
||||
- |
|
||||
echo "Waiting for postgres..."
|
||||
until pg_isready -h dev-postgres -U devuser -d devdb; do sleep 1; done
|
||||
|
||||
psql -c "DROP TABLE IF EXISTS wal_generator;"
|
||||
psql -c "CREATE TABLE wal_generator (id SERIAL PRIMARY KEY, data TEXT NOT NULL);"
|
||||
echo "Starting WAL generation loop..."
|
||||
while true; do
|
||||
echo "Inserting ~50MB of data..."
|
||||
psql -c "INSERT INTO wal_generator (data) SELECT repeat(md5(random()::text), 640) FROM generate_series(1, 2500);"
|
||||
echo "Deleting data..."
|
||||
psql -c "DELETE FROM wal_generator;"
|
||||
echo "Cycle complete, sleeping 5s..."
|
||||
sleep 5
|
||||
done
|
||||
2
agent/e2e/.gitignore
vendored
Normal file
2
agent/e2e/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
artifacts/
|
||||
pgdata/
|
||||
13
agent/e2e/Dockerfile.agent-builder
Normal file
13
agent/e2e/Dockerfile.agent-builder
Normal file
@@ -0,0 +1,13 @@
|
||||
# Builds agent binaries with different versions so
|
||||
# we can test upgrade behavior (v1 -> v2)
|
||||
FROM golang:1.26.1-alpine AS build
|
||||
WORKDIR /src
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
RUN CGO_ENABLED=0 go build -ldflags "-X main.Version=v1.0.0" -o /out/agent-v1 ./cmd/main.go
|
||||
RUN CGO_ENABLED=0 go build -ldflags "-X main.Version=v2.0.0" -o /out/agent-v2 ./cmd/main.go
|
||||
|
||||
FROM alpine:3.21
|
||||
COPY --from=build /out/ /out/
|
||||
CMD ["cp", "-v", "/out/agent-v1", "/out/agent-v2", "/artifacts/"]
|
||||
22
agent/e2e/Dockerfile.agent-docker
Normal file
22
agent/e2e/Dockerfile.agent-docker
Normal file
@@ -0,0 +1,22 @@
|
||||
# Runs backup-restore via docker exec test (test 6). Needs both Docker
|
||||
# CLI (for pg_basebackup via docker exec) and PostgreSQL server (for
|
||||
# restore verification).
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ca-certificates curl gnupg2 locales postgresql-common && \
|
||||
sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen && \
|
||||
locale-gen && \
|
||||
/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
postgresql-17 && \
|
||||
install -m 0755 -d /etc/apt/keyrings && \
|
||||
curl -fsSL https://download.docker.com/linux/debian/gpg -o /etc/apt/keyrings/docker.asc && \
|
||||
echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/debian bookworm stable" > /etc/apt/sources.list.d/docker.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends docker-ce-cli && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /tmp
|
||||
ENTRYPOINT []
|
||||
14
agent/e2e/Dockerfile.agent-runner
Normal file
14
agent/e2e/Dockerfile.agent-runner
Normal file
@@ -0,0 +1,14 @@
|
||||
# Runs upgrade and host-mode backup-restore tests (tests 1-5). Needs
|
||||
# full PostgreSQL server for backup-restore lifecycle tests.
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ca-certificates curl gnupg2 postgresql-common && \
|
||||
/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
postgresql-17 && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /tmp
|
||||
ENTRYPOINT []
|
||||
16
agent/e2e/Dockerfile.backup-restore-runner
Normal file
16
agent/e2e/Dockerfile.backup-restore-runner
Normal file
@@ -0,0 +1,16 @@
|
||||
# Runs backup-restore lifecycle tests with a specific PostgreSQL version.
|
||||
# Used for PG version matrix testing (15, 16, 17, 18).
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
ARG PG_VERSION=17
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ca-certificates curl gnupg2 postgresql-common && \
|
||||
/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
postgresql-${PG_VERSION} && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /tmp
|
||||
ENTRYPOINT []
|
||||
10
agent/e2e/Dockerfile.mock-server
Normal file
10
agent/e2e/Dockerfile.mock-server
Normal file
@@ -0,0 +1,10 @@
|
||||
# Mock databasus API server for version checks and binary downloads. Just
|
||||
# serves static responses and files from the `artifacts` directory.
|
||||
FROM golang:1.26.1-alpine AS build
|
||||
WORKDIR /app
|
||||
COPY mock-server/main.go .
|
||||
RUN CGO_ENABLED=0 go build -o mock-server main.go
|
||||
|
||||
FROM alpine:3.21
|
||||
COPY --from=build /app/mock-server /usr/local/bin/mock-server
|
||||
ENTRYPOINT ["mock-server"]
|
||||
33
agent/e2e/docker-compose.backup-restore.yml
Normal file
33
agent/e2e/docker-compose.backup-restore.yml
Normal file
@@ -0,0 +1,33 @@
|
||||
services:
|
||||
e2e-br-mock-server:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.mock-server
|
||||
volumes:
|
||||
- backup-storage:/backup-storage
|
||||
container_name: e2e-br-mock-server
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "-q", "--spider", "http://localhost:4050/health"]
|
||||
interval: 2s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
e2e-br-runner:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.backup-restore-runner
|
||||
args:
|
||||
PG_VERSION: ${PG_VERSION:-17}
|
||||
volumes:
|
||||
- ./artifacts:/opt/agent/artifacts:ro
|
||||
- ./scripts:/opt/agent/scripts:ro
|
||||
depends_on:
|
||||
e2e-br-mock-server:
|
||||
condition: service_healthy
|
||||
container_name: e2e-br-runner
|
||||
command: ["bash", "/opt/agent/scripts/test-pg-host-path.sh"]
|
||||
environment:
|
||||
MOCK_SERVER_OVERRIDE: "http://e2e-br-mock-server:4050"
|
||||
|
||||
volumes:
|
||||
backup-storage:
|
||||
84
agent/e2e/docker-compose.yml
Normal file
84
agent/e2e/docker-compose.yml
Normal file
@@ -0,0 +1,84 @@
|
||||
services:
|
||||
e2e-agent-builder:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: e2e/Dockerfile.agent-builder
|
||||
volumes:
|
||||
- ./artifacts:/artifacts
|
||||
container_name: e2e-agent-builder
|
||||
|
||||
e2e-postgres:
|
||||
image: postgres:17
|
||||
environment:
|
||||
POSTGRES_DB: testdb
|
||||
POSTGRES_USER: testuser
|
||||
POSTGRES_PASSWORD: testpassword
|
||||
container_name: e2e-agent-postgres
|
||||
command:
|
||||
- bash
|
||||
- -c
|
||||
- |
|
||||
mkdir -p /wal-queue && chown postgres:postgres /wal-queue
|
||||
exec docker-entrypoint.sh postgres \
|
||||
-c wal_level=replica \
|
||||
-c max_wal_senders=3 \
|
||||
-c archive_mode=on \
|
||||
-c "archive_command=cp %p /wal-queue/%f"
|
||||
volumes:
|
||||
- ./pgdata:/var/lib/postgresql/data
|
||||
- wal-queue:/wal-queue
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U testuser -d testdb"]
|
||||
interval: 2s
|
||||
timeout: 5s
|
||||
retries: 30
|
||||
|
||||
e2e-mock-server:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.mock-server
|
||||
volumes:
|
||||
- ./artifacts:/artifacts:ro
|
||||
- backup-storage:/backup-storage
|
||||
container_name: e2e-mock-server
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "-q", "--spider", "http://localhost:4050/health"]
|
||||
interval: 2s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
e2e-agent-runner:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.agent-runner
|
||||
volumes:
|
||||
- ./artifacts:/opt/agent/artifacts:ro
|
||||
- ./scripts:/opt/agent/scripts:ro
|
||||
depends_on:
|
||||
e2e-postgres:
|
||||
condition: service_healthy
|
||||
e2e-mock-server:
|
||||
condition: service_healthy
|
||||
container_name: e2e-agent-runner
|
||||
command: ["bash", "/opt/agent/scripts/run-all.sh", "host"]
|
||||
|
||||
e2e-agent-docker:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.agent-docker
|
||||
volumes:
|
||||
- ./artifacts:/opt/agent/artifacts:ro
|
||||
- ./scripts:/opt/agent/scripts:ro
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- wal-queue:/wal-queue
|
||||
depends_on:
|
||||
e2e-postgres:
|
||||
condition: service_healthy
|
||||
e2e-mock-server:
|
||||
condition: service_healthy
|
||||
container_name: e2e-agent-docker
|
||||
command: ["bash", "/opt/agent/scripts/run-all.sh", "docker"]
|
||||
|
||||
volumes:
|
||||
wal-queue:
|
||||
backup-storage:
|
||||
477
agent/e2e/mock-server/main.go
Normal file
477
agent/e2e/mock-server/main.go
Normal file
@@ -0,0 +1,477 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const backupStorageDir = "/backup-storage"
|
||||
|
||||
type walSegment struct {
|
||||
BackupID string
|
||||
SegmentName string
|
||||
FilePath string
|
||||
SizeBytes int64
|
||||
}
|
||||
|
||||
type server struct {
|
||||
mu sync.RWMutex
|
||||
version string
|
||||
binaryPath string
|
||||
|
||||
backupID string
|
||||
backupFilePath string
|
||||
startSegment string
|
||||
stopSegment string
|
||||
isFinalized bool
|
||||
walSegments []walSegment
|
||||
backupCreatedAt time.Time
|
||||
}
|
||||
|
||||
func main() {
|
||||
version := "v2.0.0"
|
||||
binaryPath := "/artifacts/agent-v2"
|
||||
port := "4050"
|
||||
|
||||
_ = os.MkdirAll(backupStorageDir, 0o755)
|
||||
|
||||
s := &server{version: version, binaryPath: binaryPath}
|
||||
|
||||
// System endpoints
|
||||
http.HandleFunc("/api/v1/system/version", s.handleVersion)
|
||||
http.HandleFunc("/api/v1/system/agent", s.handleAgentDownload)
|
||||
|
||||
// Backup endpoints
|
||||
http.HandleFunc("/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup", s.handleChainValidity)
|
||||
http.HandleFunc("/api/v1/backups/postgres/wal/next-full-backup-time", s.handleNextBackupTime)
|
||||
http.HandleFunc("/api/v1/backups/postgres/wal/upload/full-start", s.handleFullStart)
|
||||
http.HandleFunc("/api/v1/backups/postgres/wal/upload/full-complete", s.handleFullComplete)
|
||||
http.HandleFunc("/api/v1/backups/postgres/wal/upload/wal", s.handleWalUpload)
|
||||
http.HandleFunc("/api/v1/backups/postgres/wal/error", s.handleError)
|
||||
|
||||
// Restore endpoints
|
||||
http.HandleFunc("/api/v1/backups/postgres/wal/restore/plan", s.handleRestorePlan)
|
||||
http.HandleFunc("/api/v1/backups/postgres/wal/restore/download", s.handleRestoreDownload)
|
||||
|
||||
// Mock control endpoints
|
||||
http.HandleFunc("/mock/set-version", s.handleSetVersion)
|
||||
http.HandleFunc("/mock/set-binary-path", s.handleSetBinaryPath)
|
||||
http.HandleFunc("/mock/backup-status", s.handleBackupStatus)
|
||||
http.HandleFunc("/mock/reset", s.handleReset)
|
||||
http.HandleFunc("/health", s.handleHealth)
|
||||
|
||||
addr := ":" + port
|
||||
log.Printf("Mock server starting on %s (version=%s, binary=%s)", addr, version, binaryPath)
|
||||
|
||||
if err := http.ListenAndServe(addr, nil); err != nil {
|
||||
log.Fatalf("Server failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- System handlers ---
|
||||
|
||||
func (s *server) handleVersion(w http.ResponseWriter, _ *http.Request) {
|
||||
s.mu.RLock()
|
||||
v := s.version
|
||||
s.mu.RUnlock()
|
||||
|
||||
log.Printf("GET /api/v1/system/version -> %s", v)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"version": v})
|
||||
}
|
||||
|
||||
func (s *server) handleAgentDownload(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.RLock()
|
||||
path := s.binaryPath
|
||||
s.mu.RUnlock()
|
||||
|
||||
log.Printf("GET /api/v1/system/agent (arch=%s) -> serving %s", r.URL.Query().Get("arch"), path)
|
||||
|
||||
http.ServeFile(w, r, path)
|
||||
}
|
||||
|
||||
// --- Backup handlers ---
|
||||
|
||||
func (s *server) handleChainValidity(w http.ResponseWriter, _ *http.Request) {
|
||||
s.mu.RLock()
|
||||
isFinalized := s.isFinalized
|
||||
s.mu.RUnlock()
|
||||
|
||||
log.Printf("GET chain-validity -> isFinalized=%v", isFinalized)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if isFinalized {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"isValid": true,
|
||||
})
|
||||
} else {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"isValid": false,
|
||||
"error": "no full backup found",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) handleNextBackupTime(w http.ResponseWriter, _ *http.Request) {
|
||||
log.Printf("GET next-full-backup-time")
|
||||
|
||||
nextTime := time.Now().UTC().Add(1 * time.Hour)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"nextFullBackupTime": nextTime.Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *server) handleFullStart(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
backupID := generateID()
|
||||
filePath := filepath.Join(backupStorageDir, backupID+".zst")
|
||||
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
log.Printf("ERROR creating backup file: %v", err)
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
bytesWritten, err := io.Copy(file, r.Body)
|
||||
_ = file.Close()
|
||||
|
||||
if err != nil {
|
||||
log.Printf("ERROR writing backup data: %v", err)
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.backupID = backupID
|
||||
s.backupFilePath = filePath
|
||||
s.backupCreatedAt = time.Now().UTC()
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Printf("POST full-start -> backupID=%s, size=%d bytes", backupID, bytesWritten)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"backupId": backupID})
|
||||
}
|
||||
|
||||
func (s *server) handleFullComplete(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var body struct {
|
||||
BackupID string `json:"backupId"`
|
||||
StartSegment string `json:"startSegment"`
|
||||
StopSegment string `json:"stopSegment"`
|
||||
Error *string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if body.Error != nil {
|
||||
log.Printf("POST full-complete -> backupID=%s ERROR: %s", body.BackupID, *body.Error)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.startSegment = body.StartSegment
|
||||
s.stopSegment = body.StopSegment
|
||||
s.isFinalized = true
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Printf(
|
||||
"POST full-complete -> backupID=%s, start=%s, stop=%s",
|
||||
body.BackupID,
|
||||
body.StartSegment,
|
||||
body.StopSegment,
|
||||
)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (s *server) handleWalUpload(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
segmentName := r.Header.Get("X-Wal-Segment-Name")
|
||||
if segmentName == "" {
|
||||
http.Error(w, "missing X-Wal-Segment-Name header", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
walBackupID := generateID()
|
||||
filePath := filepath.Join(backupStorageDir, walBackupID+".zst")
|
||||
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
log.Printf("ERROR creating WAL file: %v", err)
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
bytesWritten, err := io.Copy(file, r.Body)
|
||||
_ = file.Close()
|
||||
|
||||
if err != nil {
|
||||
log.Printf("ERROR writing WAL data: %v", err)
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.walSegments = append(s.walSegments, walSegment{
|
||||
BackupID: walBackupID,
|
||||
SegmentName: segmentName,
|
||||
FilePath: filePath,
|
||||
SizeBytes: bytesWritten,
|
||||
})
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Printf("POST wal-upload -> segment=%s, walBackupID=%s, size=%d", segmentName, walBackupID, bytesWritten)
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (s *server) handleError(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
log.Printf("POST error -> failed to decode: %v", err)
|
||||
} else {
|
||||
log.Printf("POST error -> %s", body.Error)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// --- Restore handlers ---
|
||||
|
||||
func (s *server) handleRestorePlan(w http.ResponseWriter, _ *http.Request) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if !s.isFinalized {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "no_backups",
|
||||
"message": "No full backups available",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
backupFileInfo, err := os.Stat(s.backupFilePath)
|
||||
if err != nil {
|
||||
log.Printf("ERROR stat backup file: %v", err)
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
backupSizeBytes := backupFileInfo.Size()
|
||||
totalSizeBytes := backupSizeBytes
|
||||
|
||||
walSegmentsJSON := make([]map[string]any, 0, len(s.walSegments))
|
||||
|
||||
latestSegment := ""
|
||||
|
||||
for _, segment := range s.walSegments {
|
||||
totalSizeBytes += segment.SizeBytes
|
||||
latestSegment = segment.SegmentName
|
||||
|
||||
walSegmentsJSON = append(walSegmentsJSON, map[string]any{
|
||||
"backupId": segment.BackupID,
|
||||
"segmentName": segment.SegmentName,
|
||||
"sizeBytes": segment.SizeBytes,
|
||||
})
|
||||
}
|
||||
|
||||
response := map[string]any{
|
||||
"fullBackup": map[string]any{
|
||||
"id": s.backupID,
|
||||
"fullBackupWalStartSegment": s.startSegment,
|
||||
"fullBackupWalStopSegment": s.stopSegment,
|
||||
"pgVersion": "17",
|
||||
"createdAt": s.backupCreatedAt.Format(time.RFC3339),
|
||||
"sizeBytes": backupSizeBytes,
|
||||
},
|
||||
"walSegments": walSegmentsJSON,
|
||||
"totalSizeBytes": totalSizeBytes,
|
||||
"latestAvailableSegment": latestSegment,
|
||||
}
|
||||
|
||||
log.Printf("GET restore-plan -> backupID=%s, walSegments=%d", s.backupID, len(s.walSegments))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
func (s *server) handleRestoreDownload(w http.ResponseWriter, r *http.Request) {
|
||||
requestedBackupID := r.URL.Query().Get("backupId")
|
||||
if requestedBackupID == "" {
|
||||
http.Error(w, "missing backupId query param", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
filePath := s.findBackupFile(requestedBackupID)
|
||||
if filePath == "" {
|
||||
log.Printf("GET restore-download -> backupId=%s NOT FOUND", requestedBackupID)
|
||||
http.Error(w, "backup not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("GET restore-download -> backupId=%s, file=%s", requestedBackupID, filePath)
|
||||
|
||||
http.ServeFile(w, r, filePath)
|
||||
}
|
||||
|
||||
// --- Mock control handlers ---
|
||||
|
||||
func (s *server) handleSetVersion(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.version = body.Version
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Printf("POST /mock/set-version -> %s", body.Version)
|
||||
|
||||
_, _ = fmt.Fprintf(w, "version set to %s", body.Version)
|
||||
}
|
||||
|
||||
func (s *server) handleSetBinaryPath(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var body struct {
|
||||
BinaryPath string `json:"binaryPath"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.binaryPath = body.BinaryPath
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Printf("POST /mock/set-binary-path -> %s", body.BinaryPath)
|
||||
|
||||
_, _ = fmt.Fprintf(w, "binary path set to %s", body.BinaryPath)
|
||||
}
|
||||
|
||||
func (s *server) handleBackupStatus(w http.ResponseWriter, _ *http.Request) {
|
||||
s.mu.RLock()
|
||||
isFinalized := s.isFinalized
|
||||
walSegmentCount := len(s.walSegments)
|
||||
s.mu.RUnlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"isFinalized": isFinalized,
|
||||
"walSegmentCount": walSegmentCount,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *server) handleReset(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.backupID = ""
|
||||
s.backupFilePath = ""
|
||||
s.startSegment = ""
|
||||
s.stopSegment = ""
|
||||
s.isFinalized = false
|
||||
s.walSegments = nil
|
||||
s.backupCreatedAt = time.Time{}
|
||||
s.mu.Unlock()
|
||||
|
||||
// Clean stored files
|
||||
entries, _ := os.ReadDir(backupStorageDir)
|
||||
for _, entry := range entries {
|
||||
_ = os.Remove(filepath.Join(backupStorageDir, entry.Name()))
|
||||
}
|
||||
|
||||
log.Printf("POST /mock/reset -> state cleared")
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}
|
||||
|
||||
func (s *server) handleHealth(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}
|
||||
|
||||
// --- Private helpers ---
|
||||
|
||||
func generateID() string {
|
||||
b := make([]byte, 16)
|
||||
_, _ = rand.Read(b)
|
||||
|
||||
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
|
||||
}
|
||||
|
||||
func (s *server) findBackupFile(backupID string) string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.backupID == backupID {
|
||||
return s.backupFilePath
|
||||
}
|
||||
|
||||
for _, segment := range s.walSegments {
|
||||
if segment.BackupID == backupID {
|
||||
return segment.FilePath
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
357
agent/e2e/scripts/backup-restore-helpers.sh
Normal file
357
agent/e2e/scripts/backup-restore-helpers.sh
Normal file
@@ -0,0 +1,357 @@
|
||||
#!/bin/bash
|
||||
# Shared helper functions for backup-restore E2E tests.
|
||||
# Source this file from test scripts: source "$(dirname "$0")/backup-restore-helpers.sh"
|
||||
|
||||
AGENT="/tmp/test-agent"
|
||||
AGENT_PID=""
|
||||
|
||||
cleanup_agent() {
|
||||
if [ -n "$AGENT_PID" ]; then
|
||||
kill "$AGENT_PID" 2>/dev/null || true
|
||||
wait "$AGENT_PID" 2>/dev/null || true
|
||||
AGENT_PID=""
|
||||
fi
|
||||
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
}
|
||||
|
||||
setup_agent() {
|
||||
local artifacts="${1:-/opt/agent/artifacts}"
|
||||
|
||||
cleanup_agent
|
||||
cp "$artifacts/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
}
|
||||
|
||||
init_pg_local() {
|
||||
local pgdata="$1"
|
||||
local port="$2"
|
||||
local wal_queue="$3"
|
||||
local pg_bin_dir="$4"
|
||||
|
||||
# Stop any leftover PG from previous test runs
|
||||
su postgres -c "$pg_bin_dir/pg_ctl -D $pgdata stop -m immediate" 2>/dev/null || true
|
||||
su postgres -c "$pg_bin_dir/pg_ctl -D /tmp/restore-pgdata stop -m immediate" 2>/dev/null || true
|
||||
|
||||
mkdir -p "$wal_queue"
|
||||
chown postgres:postgres "$wal_queue"
|
||||
rm -rf "$pgdata"
|
||||
|
||||
su postgres -c "$pg_bin_dir/initdb -D $pgdata" > /dev/null
|
||||
|
||||
cat >> "$pgdata/postgresql.conf" <<PGCONF
|
||||
wal_level = replica
|
||||
archive_mode = on
|
||||
archive_command = 'cp %p $wal_queue/%f'
|
||||
max_wal_senders = 3
|
||||
listen_addresses = 'localhost'
|
||||
port = $port
|
||||
checkpoint_timeout = 30s
|
||||
PGCONF
|
||||
|
||||
echo "local all all trust" > "$pgdata/pg_hba.conf"
|
||||
echo "host all all 127.0.0.1/32 trust" >> "$pgdata/pg_hba.conf"
|
||||
echo "host all all ::1/128 trust" >> "$pgdata/pg_hba.conf"
|
||||
echo "local replication all trust" >> "$pgdata/pg_hba.conf"
|
||||
echo "host replication all 127.0.0.1/32 trust" >> "$pgdata/pg_hba.conf"
|
||||
echo "host replication all ::1/128 trust" >> "$pgdata/pg_hba.conf"
|
||||
|
||||
su postgres -c "$pg_bin_dir/pg_ctl -D $pgdata -l /tmp/pg.log start -w"
|
||||
|
||||
su postgres -c "$pg_bin_dir/psql -p $port -c \"CREATE USER testuser WITH SUPERUSER REPLICATION;\"" > /dev/null 2>&1 || true
|
||||
su postgres -c "$pg_bin_dir/psql -p $port -c \"CREATE DATABASE testdb OWNER testuser;\"" > /dev/null 2>&1 || true
|
||||
|
||||
echo "PostgreSQL initialized and started on port $port"
|
||||
}
|
||||
|
||||
insert_test_data() {
|
||||
local port="$1"
|
||||
local pg_bin_dir="$2"
|
||||
|
||||
su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb" <<SQL
|
||||
CREATE TABLE e2e_test_data (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
value INT NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
|
||||
INSERT INTO e2e_test_data (name, value) VALUES
|
||||
('row1', 100),
|
||||
('row2', 200),
|
||||
('row3', 300);
|
||||
SQL
|
||||
|
||||
echo "Test data inserted (3 rows)"
|
||||
}
|
||||
|
||||
force_checkpoint() {
|
||||
local port="$1"
|
||||
local pg_bin_dir="$2"
|
||||
|
||||
su postgres -c "$pg_bin_dir/psql -p $port -c 'CHECKPOINT;'" > /dev/null
|
||||
echo "Checkpoint forced"
|
||||
}
|
||||
|
||||
run_agent_backup() {
|
||||
local mock_server="$1"
|
||||
local pg_host="$2"
|
||||
local pg_port="$3"
|
||||
local wal_queue="$4"
|
||||
local pg_type="$5"
|
||||
local pg_host_bin_dir="${6:-}"
|
||||
local pg_docker_container="${7:-}"
|
||||
|
||||
# Reset mock server state and set version to match agent (prevents background upgrade loop)
|
||||
curl -sf -X POST "$mock_server/mock/reset" > /dev/null
|
||||
curl -sf -X POST "$mock_server/mock/set-version" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"version":"v1.0.0"}' > /dev/null
|
||||
|
||||
# Build JSON config
|
||||
cd /tmp
|
||||
|
||||
local extra_fields=""
|
||||
if [ -n "$pg_host_bin_dir" ]; then
|
||||
extra_fields="$extra_fields\"pgHostBinDir\": \"$pg_host_bin_dir\","
|
||||
fi
|
||||
if [ -n "$pg_docker_container" ]; then
|
||||
extra_fields="$extra_fields\"pgDockerContainerName\": \"$pg_docker_container\","
|
||||
fi
|
||||
|
||||
cat > databasus.json <<AGENTCONF
|
||||
{
|
||||
"databasusHost": "$mock_server",
|
||||
"dbId": "test-db-id",
|
||||
"token": "test-token",
|
||||
"pgHost": "$pg_host",
|
||||
"pgPort": $pg_port,
|
||||
"pgUser": "testuser",
|
||||
"pgPassword": "",
|
||||
${extra_fields}
|
||||
"pgType": "$pg_type",
|
||||
"pgWalDir": "$wal_queue",
|
||||
"deleteWalAfterUpload": true
|
||||
}
|
||||
AGENTCONF
|
||||
|
||||
# Run agent daemon in background
|
||||
"$AGENT" _run > /tmp/agent-output.log 2>&1 &
|
||||
AGENT_PID=$!
|
||||
|
||||
echo "Agent started with PID $AGENT_PID"
|
||||
}
|
||||
|
||||
generate_wal_background() {
|
||||
local port="$1"
|
||||
local pg_bin_dir="$2"
|
||||
|
||||
while true; do
|
||||
su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb -c \"
|
||||
INSERT INTO e2e_test_data (name, value)
|
||||
SELECT 'bulk_' || g, g FROM generate_series(1, 1000) g;
|
||||
SELECT pg_switch_wal();
|
||||
\"" > /dev/null 2>&1 || break
|
||||
sleep 2
|
||||
done
|
||||
}
|
||||
|
||||
generate_wal_docker_background() {
|
||||
local container="$1"
|
||||
|
||||
while true; do
|
||||
docker exec "$container" psql -U testuser -d testdb -c "
|
||||
INSERT INTO e2e_test_data (name, value)
|
||||
SELECT 'bulk_' || g, g FROM generate_series(1, 1000) g;
|
||||
SELECT pg_switch_wal();
|
||||
" > /dev/null 2>&1 || break
|
||||
sleep 2
|
||||
done
|
||||
}
|
||||
|
||||
wait_for_backup_complete() {
|
||||
local mock_server="$1"
|
||||
local timeout="${2:-120}"
|
||||
|
||||
echo "Waiting for backup to complete (timeout: ${timeout}s)..."
|
||||
|
||||
for i in $(seq 1 "$timeout"); do
|
||||
STATUS=$(curl -sf "$mock_server/mock/backup-status" 2>/dev/null || echo '{}')
|
||||
IS_FINALIZED=$(echo "$STATUS" | grep -o '"isFinalized":true' || true)
|
||||
WAL_COUNT=$(echo "$STATUS" | grep -o '"walSegmentCount":[0-9]*' | grep -o '[0-9]*$' || echo "0")
|
||||
|
||||
if [ -n "$IS_FINALIZED" ] && [ "$WAL_COUNT" -gt 0 ]; then
|
||||
echo "Backup complete: finalized with $WAL_COUNT WAL segments"
|
||||
return 0
|
||||
fi
|
||||
|
||||
sleep 1
|
||||
done
|
||||
|
||||
echo "FAIL: Backup did not complete within ${timeout} seconds"
|
||||
echo "Last status: $STATUS"
|
||||
echo "Agent output:"
|
||||
cat /tmp/agent-output.log 2>/dev/null || true
|
||||
return 1
|
||||
}
|
||||
|
||||
stop_agent() {
|
||||
if [ -n "$AGENT_PID" ]; then
|
||||
kill "$AGENT_PID" 2>/dev/null || true
|
||||
wait "$AGENT_PID" 2>/dev/null || true
|
||||
AGENT_PID=""
|
||||
fi
|
||||
|
||||
echo "Agent stopped"
|
||||
}
|
||||
|
||||
stop_pg() {
|
||||
local pgdata="$1"
|
||||
local pg_bin_dir="$2"
|
||||
|
||||
su postgres -c "$pg_bin_dir/pg_ctl -D $pgdata stop -m fast" 2>/dev/null || true
|
||||
|
||||
echo "PostgreSQL stopped"
|
||||
}
|
||||
|
||||
run_agent_restore() {
|
||||
local mock_server="$1"
|
||||
local restore_dir="$2"
|
||||
|
||||
rm -rf "$restore_dir"
|
||||
mkdir -p "$restore_dir"
|
||||
chown postgres:postgres "$restore_dir"
|
||||
|
||||
cd /tmp
|
||||
|
||||
"$AGENT" restore \
|
||||
--skip-update \
|
||||
--databasus-host "$mock_server" \
|
||||
--token test-token \
|
||||
--target-dir "$restore_dir"
|
||||
|
||||
echo "Agent restore completed"
|
||||
}
|
||||
|
||||
start_restored_pg() {
|
||||
local restore_dir="$1"
|
||||
local port="$2"
|
||||
local pg_bin_dir="$3"
|
||||
|
||||
# Ensure port is set in restored config
|
||||
if ! grep -q "^port" "$restore_dir/postgresql.conf" 2>/dev/null; then
|
||||
echo "port = $port" >> "$restore_dir/postgresql.conf"
|
||||
fi
|
||||
|
||||
# Ensure listen_addresses is set
|
||||
if ! grep -q "^listen_addresses" "$restore_dir/postgresql.conf" 2>/dev/null; then
|
||||
echo "listen_addresses = 'localhost'" >> "$restore_dir/postgresql.conf"
|
||||
fi
|
||||
|
||||
chown -R postgres:postgres "$restore_dir"
|
||||
chmod 700 "$restore_dir"
|
||||
|
||||
if ! su postgres -c "$pg_bin_dir/pg_ctl -D $restore_dir -l /tmp/pg-restore.log start -w"; then
|
||||
echo "FAIL: PostgreSQL failed to start on restored data"
|
||||
echo "--- pg-restore.log ---"
|
||||
cat /tmp/pg-restore.log 2>/dev/null || echo "(no log file)"
|
||||
echo "--- postgresql.auto.conf ---"
|
||||
cat "$restore_dir/postgresql.auto.conf" 2>/dev/null || echo "(no file)"
|
||||
echo "--- pg_wal/ listing ---"
|
||||
ls -la "$restore_dir/pg_wal/" 2>/dev/null || echo "(no pg_wal dir)"
|
||||
echo "--- databasus-wal-restore/ listing ---"
|
||||
ls -la "$restore_dir/databasus-wal-restore/" 2>/dev/null || echo "(no dir)"
|
||||
echo "--- end diagnostics ---"
|
||||
return 1
|
||||
fi
|
||||
|
||||
echo "PostgreSQL started on restored data"
|
||||
}
|
||||
|
||||
wait_for_recovery_complete() {
|
||||
local port="$1"
|
||||
local pg_bin_dir="$2"
|
||||
local timeout="${3:-60}"
|
||||
|
||||
echo "Waiting for recovery to complete (timeout: ${timeout}s)..."
|
||||
|
||||
for i in $(seq 1 "$timeout"); do
|
||||
IS_READY=$(su postgres -c "$pg_bin_dir/pg_isready -p $port" 2>&1 || true)
|
||||
|
||||
if echo "$IS_READY" | grep -q "accepting connections"; then
|
||||
IN_RECOVERY=$(su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb -t -c 'SELECT pg_is_in_recovery();'" 2>/dev/null | tr -d ' \n' || echo "t")
|
||||
|
||||
if [ "$IN_RECOVERY" = "f" ]; then
|
||||
echo "PostgreSQL recovered and promoted to primary"
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
|
||||
sleep 1
|
||||
done
|
||||
|
||||
echo "FAIL: PostgreSQL did not recover within ${timeout} seconds"
|
||||
echo "Recovery log:"
|
||||
cat /tmp/pg-restore.log 2>/dev/null || true
|
||||
return 1
|
||||
}
|
||||
|
||||
verify_restored_data() {
|
||||
local port="$1"
|
||||
local pg_bin_dir="$2"
|
||||
|
||||
ROW_COUNT=$(su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb -t -c 'SELECT COUNT(*) FROM e2e_test_data;'" | tr -d ' \n')
|
||||
|
||||
if [ "$ROW_COUNT" -lt 3 ]; then
|
||||
echo "FAIL: Expected at least 3 rows, got $ROW_COUNT"
|
||||
su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb -c 'SELECT * FROM e2e_test_data;'"
|
||||
return 1
|
||||
fi
|
||||
|
||||
RESULT=$(su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb -t -c \"SELECT value FROM e2e_test_data WHERE name='row1';\"" | tr -d ' \n')
|
||||
|
||||
if [ "$RESULT" != "100" ]; then
|
||||
echo "FAIL: Expected row1 value=100, got $RESULT"
|
||||
return 1
|
||||
fi
|
||||
|
||||
RESULT2=$(su postgres -c "$pg_bin_dir/psql -p $port -U testuser -d testdb -t -c \"SELECT value FROM e2e_test_data WHERE name='row3';\"" | tr -d ' \n')
|
||||
|
||||
if [ "$RESULT2" != "300" ]; then
|
||||
echo "FAIL: Expected row3 value=300, got $RESULT2"
|
||||
return 1
|
||||
fi
|
||||
|
||||
echo "PASS: Found $ROW_COUNT rows, data integrity verified"
|
||||
return 0
|
||||
}
|
||||
|
||||
find_pg_bin_dir() {
|
||||
# Find the PG bin dir from the installed version
|
||||
local pg_config_path
|
||||
pg_config_path=$(which pg_config 2>/dev/null || true)
|
||||
|
||||
if [ -n "$pg_config_path" ]; then
|
||||
pg_config --bindir
|
||||
return
|
||||
fi
|
||||
|
||||
# Fallback: search common locations
|
||||
for version in 18 17 16 15; do
|
||||
if [ -d "/usr/lib/postgresql/$version/bin" ]; then
|
||||
echo "/usr/lib/postgresql/$version/bin"
|
||||
return
|
||||
fi
|
||||
done
|
||||
|
||||
echo "ERROR: Cannot find PostgreSQL bin directory" >&2
|
||||
return 1
|
||||
}
|
||||
56
agent/e2e/scripts/run-all.sh
Normal file
56
agent/e2e/scripts/run-all.sh
Normal file
@@ -0,0 +1,56 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
MODE="${1:-host}"
|
||||
SCRIPT_DIR="$(dirname "$0")"
|
||||
PASSED=0
|
||||
FAILED=0
|
||||
FAILED_NAMES=""
|
||||
|
||||
run_test() {
|
||||
local name="$1"
|
||||
local script="$2"
|
||||
|
||||
echo ""
|
||||
echo "========================================"
|
||||
echo " $name"
|
||||
echo "========================================"
|
||||
|
||||
if bash "$script"; then
|
||||
echo " PASSED: $name"
|
||||
PASSED=$((PASSED + 1))
|
||||
else
|
||||
echo " FAILED: $name"
|
||||
FAILED=$((FAILED + 1))
|
||||
FAILED_NAMES="${FAILED_NAMES}\n - ${name}"
|
||||
fi
|
||||
}
|
||||
|
||||
if [ "$MODE" = "host" ]; then
|
||||
run_test "Test 1: Upgrade success (v1 -> v2)" "$SCRIPT_DIR/test-upgrade-success.sh"
|
||||
run_test "Test 2: Upgrade skip (version matches)" "$SCRIPT_DIR/test-upgrade-skip.sh"
|
||||
run_test "Test 3: Background upgrade (v1 -> v2 while running)" "$SCRIPT_DIR/test-upgrade-background.sh"
|
||||
run_test "Test 4: Backup-restore via host PATH" "$SCRIPT_DIR/test-pg-host-path.sh"
|
||||
run_test "Test 5: Backup-restore via host bindir" "$SCRIPT_DIR/test-pg-host-bindir.sh"
|
||||
|
||||
elif [ "$MODE" = "docker" ]; then
|
||||
run_test "Test 6: Backup-restore via docker exec" "$SCRIPT_DIR/test-pg-docker-exec.sh"
|
||||
|
||||
else
|
||||
echo "Unknown mode: $MODE (expected 'host' or 'docker')"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "========================================"
|
||||
echo " Results: $PASSED passed, $FAILED failed"
|
||||
if [ "$FAILED" -gt 0 ]; then
|
||||
echo ""
|
||||
echo " Failed:"
|
||||
echo -e "$FAILED_NAMES"
|
||||
fi
|
||||
echo "========================================"
|
||||
|
||||
if [ "$FAILED" -gt 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
95
agent/e2e/scripts/test-pg-docker-exec.sh
Normal file
95
agent/e2e/scripts/test-pg-docker-exec.sh
Normal file
@@ -0,0 +1,95 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(dirname "$0")"
|
||||
source "$SCRIPT_DIR/backup-restore-helpers.sh"
|
||||
|
||||
MOCK_SERVER="${MOCK_SERVER_OVERRIDE:-http://e2e-mock-server:4050}"
|
||||
PG_CONTAINER="e2e-agent-postgres"
|
||||
RESTORE_PGDATA="/tmp/restore-pgdata"
|
||||
WAL_QUEUE="/wal-queue"
|
||||
PG_PORT=5432
|
||||
|
||||
# For restore verification we need a local PG bin dir
|
||||
PG_BIN_DIR=$(find_pg_bin_dir)
|
||||
echo "Using local PG bin dir for restore verification: $PG_BIN_DIR"
|
||||
|
||||
# Verify docker CLI works and PG container is accessible
|
||||
if ! docker exec "$PG_CONTAINER" pg_basebackup --version > /dev/null 2>&1; then
|
||||
echo "FAIL: Cannot reach pg_basebackup inside container $PG_CONTAINER (test setup issue)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "=== Phase 1: Setup agent ==="
|
||||
setup_agent
|
||||
|
||||
echo "=== Phase 2: Insert test data into containerized PostgreSQL ==="
|
||||
docker exec "$PG_CONTAINER" psql -U testuser -d testdb -c "
|
||||
CREATE TABLE IF NOT EXISTS e2e_test_data (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
value INT NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
DELETE FROM e2e_test_data;
|
||||
INSERT INTO e2e_test_data (name, value) VALUES
|
||||
('row1', 100),
|
||||
('row2', 200),
|
||||
('row3', 300);
|
||||
"
|
||||
echo "Test data inserted (3 rows)"
|
||||
|
||||
echo "=== Phase 3: Start agent backup (docker exec mode) ==="
|
||||
curl -sf -X POST "$MOCK_SERVER/mock/reset" > /dev/null
|
||||
|
||||
cd /tmp
|
||||
cat > databasus.json <<AGENTCONF
|
||||
{
|
||||
"databasusHost": "$MOCK_SERVER",
|
||||
"dbId": "test-db-id",
|
||||
"token": "test-token",
|
||||
"pgHost": "$PG_CONTAINER",
|
||||
"pgPort": $PG_PORT,
|
||||
"pgUser": "testuser",
|
||||
"pgPassword": "testpassword",
|
||||
"pgType": "docker",
|
||||
"pgDockerContainerName": "$PG_CONTAINER",
|
||||
"pgWalDir": "$WAL_QUEUE",
|
||||
"deleteWalAfterUpload": true
|
||||
}
|
||||
AGENTCONF
|
||||
|
||||
"$AGENT" _run > /tmp/agent-output.log 2>&1 &
|
||||
AGENT_PID=$!
|
||||
echo "Agent started with PID $AGENT_PID"
|
||||
|
||||
echo "=== Phase 4: Generate WAL in background ==="
|
||||
generate_wal_docker_background "$PG_CONTAINER" &
|
||||
WAL_GEN_PID=$!
|
||||
|
||||
echo "=== Phase 5: Wait for backup to complete ==="
|
||||
wait_for_backup_complete "$MOCK_SERVER" 120
|
||||
|
||||
echo "=== Phase 6: Stop WAL generator and agent ==="
|
||||
kill $WAL_GEN_PID 2>/dev/null || true
|
||||
wait $WAL_GEN_PID 2>/dev/null || true
|
||||
stop_agent
|
||||
|
||||
echo "=== Phase 7: Restore to local directory ==="
|
||||
run_agent_restore "$MOCK_SERVER" "$RESTORE_PGDATA"
|
||||
|
||||
echo "=== Phase 8: Start local PostgreSQL on restored data ==="
|
||||
# Use a different port to avoid conflict with the containerized PG
|
||||
RESTORE_PORT=5433
|
||||
start_restored_pg "$RESTORE_PGDATA" "$RESTORE_PORT" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 9: Wait for recovery ==="
|
||||
wait_for_recovery_complete "$RESTORE_PORT" "$PG_BIN_DIR" 60
|
||||
|
||||
echo "=== Phase 10: Verify data ==="
|
||||
verify_restored_data "$RESTORE_PORT" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 11: Cleanup ==="
|
||||
stop_pg "$RESTORE_PGDATA" "$PG_BIN_DIR"
|
||||
|
||||
echo "pg_basebackup via docker exec: full backup-restore lifecycle passed"
|
||||
62
agent/e2e/scripts/test-pg-host-bindir.sh
Normal file
62
agent/e2e/scripts/test-pg-host-bindir.sh
Normal file
@@ -0,0 +1,62 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(dirname "$0")"
|
||||
source "$SCRIPT_DIR/backup-restore-helpers.sh"
|
||||
|
||||
MOCK_SERVER="${MOCK_SERVER_OVERRIDE:-http://e2e-mock-server:4050}"
|
||||
PGDATA="/tmp/pgdata"
|
||||
RESTORE_PGDATA="/tmp/restore-pgdata"
|
||||
WAL_QUEUE="/tmp/wal-queue"
|
||||
PG_PORT=5433
|
||||
CUSTOM_BIN_DIR="/opt/pg/bin"
|
||||
|
||||
PG_BIN_DIR=$(find_pg_bin_dir)
|
||||
echo "Using PG bin dir: $PG_BIN_DIR"
|
||||
|
||||
# Copy pg_basebackup to a custom directory (simulates non-PATH installation)
|
||||
mkdir -p "$CUSTOM_BIN_DIR"
|
||||
cp "$PG_BIN_DIR/pg_basebackup" "$CUSTOM_BIN_DIR/pg_basebackup"
|
||||
|
||||
echo "=== Phase 1: Setup agent ==="
|
||||
setup_agent
|
||||
|
||||
echo "=== Phase 2: Initialize PostgreSQL ==="
|
||||
init_pg_local "$PGDATA" "$PG_PORT" "$WAL_QUEUE" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 3: Insert test data ==="
|
||||
insert_test_data "$PG_PORT" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 4: Force checkpoint and start agent backup (using --pg-host-bin-dir) ==="
|
||||
force_checkpoint "$PG_PORT" "$PG_BIN_DIR"
|
||||
run_agent_backup "$MOCK_SERVER" "127.0.0.1" "$PG_PORT" "$WAL_QUEUE" "host" "$CUSTOM_BIN_DIR"
|
||||
|
||||
echo "=== Phase 5: Generate WAL in background ==="
|
||||
generate_wal_background "$PG_PORT" "$PG_BIN_DIR" &
|
||||
WAL_GEN_PID=$!
|
||||
|
||||
echo "=== Phase 6: Wait for backup to complete ==="
|
||||
wait_for_backup_complete "$MOCK_SERVER" 120
|
||||
|
||||
echo "=== Phase 7: Stop WAL generator, agent, and PostgreSQL ==="
|
||||
kill $WAL_GEN_PID 2>/dev/null || true
|
||||
wait $WAL_GEN_PID 2>/dev/null || true
|
||||
stop_agent
|
||||
stop_pg "$PGDATA" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 8: Restore ==="
|
||||
run_agent_restore "$MOCK_SERVER" "$RESTORE_PGDATA"
|
||||
|
||||
echo "=== Phase 9: Start PostgreSQL on restored data ==="
|
||||
start_restored_pg "$RESTORE_PGDATA" "$PG_PORT" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 10: Wait for recovery ==="
|
||||
wait_for_recovery_complete "$PG_PORT" "$PG_BIN_DIR" 60
|
||||
|
||||
echo "=== Phase 11: Verify data ==="
|
||||
verify_restored_data "$PG_PORT" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 12: Cleanup ==="
|
||||
stop_pg "$RESTORE_PGDATA" "$PG_BIN_DIR"
|
||||
|
||||
echo "pg_basebackup via custom bindir: full backup-restore lifecycle passed"
|
||||
63
agent/e2e/scripts/test-pg-host-path.sh
Normal file
63
agent/e2e/scripts/test-pg-host-path.sh
Normal file
@@ -0,0 +1,63 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(dirname "$0")"
|
||||
source "$SCRIPT_DIR/backup-restore-helpers.sh"
|
||||
|
||||
MOCK_SERVER="${MOCK_SERVER_OVERRIDE:-http://e2e-mock-server:4050}"
|
||||
PGDATA="/tmp/pgdata"
|
||||
RESTORE_PGDATA="/tmp/restore-pgdata"
|
||||
WAL_QUEUE="/tmp/wal-queue"
|
||||
PG_PORT=5433
|
||||
|
||||
PG_BIN_DIR=$(find_pg_bin_dir)
|
||||
echo "Using PG bin dir: $PG_BIN_DIR"
|
||||
|
||||
# Verify pg_basebackup is in PATH
|
||||
if ! which pg_basebackup > /dev/null 2>&1; then
|
||||
echo "FAIL: pg_basebackup not found in PATH (test setup issue)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "=== Phase 1: Setup agent ==="
|
||||
setup_agent
|
||||
|
||||
echo "=== Phase 2: Initialize PostgreSQL ==="
|
||||
init_pg_local "$PGDATA" "$PG_PORT" "$WAL_QUEUE" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 3: Insert test data ==="
|
||||
insert_test_data "$PG_PORT" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 4: Force checkpoint and start agent backup ==="
|
||||
force_checkpoint "$PG_PORT" "$PG_BIN_DIR"
|
||||
run_agent_backup "$MOCK_SERVER" "127.0.0.1" "$PG_PORT" "$WAL_QUEUE" "host"
|
||||
|
||||
echo "=== Phase 5: Generate WAL in background ==="
|
||||
generate_wal_background "$PG_PORT" "$PG_BIN_DIR" &
|
||||
WAL_GEN_PID=$!
|
||||
|
||||
echo "=== Phase 6: Wait for backup to complete ==="
|
||||
wait_for_backup_complete "$MOCK_SERVER" 120
|
||||
|
||||
echo "=== Phase 7: Stop WAL generator, agent, and PostgreSQL ==="
|
||||
kill $WAL_GEN_PID 2>/dev/null || true
|
||||
wait $WAL_GEN_PID 2>/dev/null || true
|
||||
stop_agent
|
||||
stop_pg "$PGDATA" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 8: Restore ==="
|
||||
run_agent_restore "$MOCK_SERVER" "$RESTORE_PGDATA"
|
||||
|
||||
echo "=== Phase 9: Start PostgreSQL on restored data ==="
|
||||
start_restored_pg "$RESTORE_PGDATA" "$PG_PORT" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 10: Wait for recovery ==="
|
||||
wait_for_recovery_complete "$PG_PORT" "$PG_BIN_DIR" 60
|
||||
|
||||
echo "=== Phase 11: Verify data ==="
|
||||
verify_restored_data "$PG_PORT" "$PG_BIN_DIR"
|
||||
|
||||
echo "=== Phase 12: Cleanup ==="
|
||||
stop_pg "$RESTORE_PGDATA" "$PG_BIN_DIR"
|
||||
|
||||
echo "pg_basebackup in PATH: full backup-restore lifecycle passed"
|
||||
90
agent/e2e/scripts/test-upgrade-background.sh
Normal file
90
agent/e2e/scripts/test-upgrade-background.sh
Normal file
@@ -0,0 +1,90 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
|
||||
# Cleanup from previous runs
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
|
||||
# Set mock server to v1.0.0 (same as agent — no sync upgrade on start)
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"version":"v1.0.0"}'
|
||||
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-binary-path \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"binaryPath":"/artifacts/agent-v1"}'
|
||||
|
||||
# Copy v1 binary to writable location
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Verify initial version
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v1.0.0" ]; then
|
||||
echo "FAIL: Expected initial version v1.0.0, got $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
echo "Initial version: $VERSION"
|
||||
|
||||
# Start agent as daemon (versions match → no sync upgrade)
|
||||
mkdir -p /tmp/wal
|
||||
"$AGENT" start \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--pg-wal-dir /tmp/wal \
|
||||
--pg-type host
|
||||
|
||||
echo "Agent started as daemon, waiting for stabilization..."
|
||||
sleep 2
|
||||
|
||||
# Change mock server to v2.0.0 and point to v2 binary
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"version":"v2.0.0"}'
|
||||
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-binary-path \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"binaryPath":"/artifacts/agent-v2"}'
|
||||
|
||||
echo "Mock server updated to v2.0.0, waiting for background upgrade..."
|
||||
|
||||
# Poll for upgrade (timeout 60s, poll every 3s)
|
||||
DEADLINE=$((SECONDS + 60))
|
||||
while [ $SECONDS -lt $DEADLINE ]; do
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" = "v2.0.0" ]; then
|
||||
echo "Binary upgraded to $VERSION"
|
||||
break
|
||||
fi
|
||||
sleep 3
|
||||
done
|
||||
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v2.0.0" ]; then
|
||||
echo "FAIL: Expected v2.0.0 after background upgrade, got $VERSION"
|
||||
cat databasus.log 2>/dev/null || true
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Verify agent is still running after restart
|
||||
sleep 2
|
||||
"$AGENT" status || true
|
||||
|
||||
# Cleanup
|
||||
"$AGENT" stop || true
|
||||
|
||||
echo "Background upgrade test passed"
|
||||
64
agent/e2e/scripts/test-upgrade-skip.sh
Normal file
64
agent/e2e/scripts/test-upgrade-skip.sh
Normal file
@@ -0,0 +1,64 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
|
||||
# Cleanup from previous runs
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
|
||||
# Set mock server to return v1.0.0 (same as agent)
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"version":"v1.0.0"}'
|
||||
|
||||
# Copy v1 binary to writable location
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Verify initial version
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v1.0.0" ]; then
|
||||
echo "FAIL: Expected initial version v1.0.0, got $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run start — agent should see version matches and skip upgrade
|
||||
echo "Running agent start (expecting upgrade skip)..."
|
||||
OUTPUT=$("$AGENT" start \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--pg-wal-dir /tmp/wal \
|
||||
--pg-type host 2>&1) || true
|
||||
|
||||
echo "$OUTPUT"
|
||||
|
||||
# Verify output contains "up to date"
|
||||
if ! echo "$OUTPUT" | grep -qi "up to date"; then
|
||||
echo "FAIL: Expected output to contain 'up to date'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Verify binary is still v1
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v1.0.0" ]; then
|
||||
echo "FAIL: Expected version v1.0.0 (unchanged), got $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Upgrade correctly skipped, version still $VERSION"
|
||||
|
||||
# Cleanup daemon
|
||||
"$AGENT" stop || true
|
||||
69
agent/e2e/scripts/test-upgrade-success.sh
Normal file
69
agent/e2e/scripts/test-upgrade-success.sh
Normal file
@@ -0,0 +1,69 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
|
||||
# Cleanup from previous runs
|
||||
pkill -f "test-agent" 2>/dev/null || true
|
||||
for i in $(seq 1 20); do
|
||||
pgrep -f "test-agent" > /dev/null 2>&1 || break
|
||||
sleep 0.5
|
||||
done
|
||||
pkill -9 -f "test-agent" 2>/dev/null || true
|
||||
sleep 0.5
|
||||
rm -f "$AGENT" "$AGENT.update" databasus.lock databasus.log databasus.log.old databasus.json 2>/dev/null || true
|
||||
|
||||
# Ensure mock server returns v2.0.0 and serves v2 binary
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"version":"v2.0.0"}'
|
||||
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-binary-path \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"binaryPath":"/artifacts/agent-v2"}'
|
||||
|
||||
# Copy v1 binary to writable location
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Verify initial version
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v1.0.0" ]; then
|
||||
echo "FAIL: Expected initial version v1.0.0, got $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
echo "Initial version: $VERSION"
|
||||
|
||||
# Run start — agent will:
|
||||
# 1. Fetch version from mock (v2.0.0 != v1.0.0)
|
||||
# 2. Download v2 binary from mock
|
||||
# 3. Replace itself on disk
|
||||
# 4. Re-exec with same args
|
||||
# 5. Re-exec'd v2 fetches version (v2.0.0 == v2.0.0) → skips update
|
||||
# 6. Proceeds to start → verifies pg_basebackup + DB → exits 0 (stub)
|
||||
echo "Running agent start (expecting upgrade v1 -> v2)..."
|
||||
OUTPUT=$("$AGENT" start \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--pg-wal-dir /tmp/wal \
|
||||
--pg-type host 2>&1) || true
|
||||
|
||||
echo "$OUTPUT"
|
||||
|
||||
# Verify binary on disk is now v2
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v2.0.0" ]; then
|
||||
echo "FAIL: Expected upgraded version v2.0.0, got $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Binary upgraded successfully to $VERSION"
|
||||
|
||||
# Cleanup daemon
|
||||
"$AGENT" stop || true
|
||||
22
agent/go.mod
Normal file
22
agent/go.mod
Normal file
@@ -0,0 +1,22 @@
|
||||
module databasus-agent
|
||||
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
github.com/go-resty/resty/v2 v2.17.2
|
||||
github.com/jackc/pgx/v5 v5.8.0
|
||||
github.com/klauspost/compress v1.18.4
|
||||
github.com/stretchr/testify v1.11.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
golang.org/x/net v0.43.0 // indirect
|
||||
golang.org/x/text v0.29.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
43
agent/go.sum
Normal file
43
agent/go.sum
Normal file
@@ -0,0 +1,43 @@
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk=
|
||||
github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
|
||||
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
|
||||
github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
|
||||
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
|
||||
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
271
agent/internal/config/config.go
Normal file
271
agent/internal/config/config.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
var log = logger.GetLogger()
|
||||
|
||||
const configFileName = "databasus.json"
|
||||
|
||||
type Config struct {
|
||||
DatabasusHost string `json:"databasusHost"`
|
||||
DbID string `json:"dbId"`
|
||||
Token string `json:"token"`
|
||||
PgHost string `json:"pgHost"`
|
||||
PgPort int `json:"pgPort"`
|
||||
PgUser string `json:"pgUser"`
|
||||
PgPassword string `json:"pgPassword"`
|
||||
PgType string `json:"pgType"`
|
||||
PgHostBinDir string `json:"pgHostBinDir"`
|
||||
PgDockerContainerName string `json:"pgDockerContainerName"`
|
||||
PgWalDir string `json:"pgWalDir"`
|
||||
IsDeleteWalAfterUpload *bool `json:"deleteWalAfterUpload"`
|
||||
|
||||
flags parsedFlags
|
||||
}
|
||||
|
||||
// LoadFromJSONAndArgs reads databasus.json into the struct
|
||||
// and overrides JSON values with any explicitly provided CLI flags.
|
||||
func (c *Config) LoadFromJSONAndArgs(fs *flag.FlagSet, args []string) {
|
||||
c.loadFromJSON()
|
||||
c.applyDefaults()
|
||||
c.initSources()
|
||||
|
||||
c.flags.databasusHost = fs.String(
|
||||
"databasus-host",
|
||||
"",
|
||||
"Databasus server URL (e.g. http://your-server:4005)",
|
||||
)
|
||||
c.flags.dbID = fs.String("db-id", "", "Database ID")
|
||||
c.flags.token = fs.String("token", "", "Agent token")
|
||||
c.flags.pgHost = fs.String("pg-host", "", "PostgreSQL host")
|
||||
c.flags.pgPort = fs.Int("pg-port", 0, "PostgreSQL port")
|
||||
c.flags.pgUser = fs.String("pg-user", "", "PostgreSQL user")
|
||||
c.flags.pgPassword = fs.String("pg-password", "", "PostgreSQL password")
|
||||
c.flags.pgType = fs.String("pg-type", "", "PostgreSQL type: host or docker")
|
||||
c.flags.pgHostBinDir = fs.String("pg-host-bin-dir", "", "Path to PG bin directory (host mode)")
|
||||
c.flags.pgDockerContainerName = fs.String("pg-docker-container-name", "", "Docker container name (docker mode)")
|
||||
c.flags.pgWalDir = fs.String("pg-wal-dir", "", "Path to WAL queue directory")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
c.applyFlags()
|
||||
log.Info("========= Loading config ============")
|
||||
c.logConfigSources()
|
||||
log.Info("========= Config has been loaded ====")
|
||||
}
|
||||
|
||||
// SaveToJSON writes the current struct to databasus.json.
|
||||
func (c *Config) SaveToJSON() error {
|
||||
data, err := json.MarshalIndent(c, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(configFileName, data, 0o644)
|
||||
}
|
||||
|
||||
func (c *Config) LoadFromJSON() {
|
||||
c.loadFromJSON()
|
||||
c.applyDefaults()
|
||||
}
|
||||
|
||||
func (c *Config) loadFromJSON() {
|
||||
data, err := os.ReadFile(configFileName)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
log.Info("No databasus.json found, will create on save")
|
||||
return
|
||||
}
|
||||
|
||||
log.Warn("Failed to read databasus.json", "error", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, c); err != nil {
|
||||
log.Warn("Failed to parse databasus.json", "error", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("Configuration loaded from " + configFileName)
|
||||
}
|
||||
|
||||
func (c *Config) applyDefaults() {
|
||||
if c.PgPort == 0 {
|
||||
c.PgPort = 5432
|
||||
}
|
||||
|
||||
if c.PgType == "" {
|
||||
c.PgType = "host"
|
||||
}
|
||||
|
||||
if c.IsDeleteWalAfterUpload == nil {
|
||||
c.IsDeleteWalAfterUpload = new(true)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) initSources() {
|
||||
c.flags.sources = map[string]string{
|
||||
"databasus-host": "not configured",
|
||||
"db-id": "not configured",
|
||||
"token": "not configured",
|
||||
"pg-host": "not configured",
|
||||
"pg-port": "not configured",
|
||||
"pg-user": "not configured",
|
||||
"pg-password": "not configured",
|
||||
"pg-type": "not configured",
|
||||
"pg-host-bin-dir": "not configured",
|
||||
"pg-docker-container-name": "not configured",
|
||||
"pg-wal-dir": "not configured",
|
||||
"delete-wal-after-upload": "not configured",
|
||||
}
|
||||
|
||||
if c.DatabasusHost != "" {
|
||||
c.flags.sources["databasus-host"] = configFileName
|
||||
}
|
||||
|
||||
if c.DbID != "" {
|
||||
c.flags.sources["db-id"] = configFileName
|
||||
}
|
||||
|
||||
if c.Token != "" {
|
||||
c.flags.sources["token"] = configFileName
|
||||
}
|
||||
|
||||
if c.PgHost != "" {
|
||||
c.flags.sources["pg-host"] = configFileName
|
||||
}
|
||||
|
||||
// PgPort always has a value after applyDefaults
|
||||
c.flags.sources["pg-port"] = configFileName
|
||||
|
||||
if c.PgUser != "" {
|
||||
c.flags.sources["pg-user"] = configFileName
|
||||
}
|
||||
|
||||
if c.PgPassword != "" {
|
||||
c.flags.sources["pg-password"] = configFileName
|
||||
}
|
||||
|
||||
// PgType always has a value after applyDefaults
|
||||
c.flags.sources["pg-type"] = configFileName
|
||||
|
||||
if c.PgHostBinDir != "" {
|
||||
c.flags.sources["pg-host-bin-dir"] = configFileName
|
||||
}
|
||||
|
||||
if c.PgDockerContainerName != "" {
|
||||
c.flags.sources["pg-docker-container-name"] = configFileName
|
||||
}
|
||||
|
||||
if c.PgWalDir != "" {
|
||||
c.flags.sources["pg-wal-dir"] = configFileName
|
||||
}
|
||||
|
||||
// IsDeleteWalAfterUpload always has a value after applyDefaults
|
||||
c.flags.sources["delete-wal-after-upload"] = configFileName
|
||||
}
|
||||
|
||||
func (c *Config) applyFlags() {
|
||||
if c.flags.databasusHost != nil && *c.flags.databasusHost != "" {
|
||||
c.DatabasusHost = *c.flags.databasusHost
|
||||
c.flags.sources["databasus-host"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.dbID != nil && *c.flags.dbID != "" {
|
||||
c.DbID = *c.flags.dbID
|
||||
c.flags.sources["db-id"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.token != nil && *c.flags.token != "" {
|
||||
c.Token = *c.flags.token
|
||||
c.flags.sources["token"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgHost != nil && *c.flags.pgHost != "" {
|
||||
c.PgHost = *c.flags.pgHost
|
||||
c.flags.sources["pg-host"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgPort != nil && *c.flags.pgPort != 0 {
|
||||
c.PgPort = *c.flags.pgPort
|
||||
c.flags.sources["pg-port"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgUser != nil && *c.flags.pgUser != "" {
|
||||
c.PgUser = *c.flags.pgUser
|
||||
c.flags.sources["pg-user"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgPassword != nil && *c.flags.pgPassword != "" {
|
||||
c.PgPassword = *c.flags.pgPassword
|
||||
c.flags.sources["pg-password"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgType != nil && *c.flags.pgType != "" {
|
||||
c.PgType = *c.flags.pgType
|
||||
c.flags.sources["pg-type"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgHostBinDir != nil && *c.flags.pgHostBinDir != "" {
|
||||
c.PgHostBinDir = *c.flags.pgHostBinDir
|
||||
c.flags.sources["pg-host-bin-dir"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgDockerContainerName != nil && *c.flags.pgDockerContainerName != "" {
|
||||
c.PgDockerContainerName = *c.flags.pgDockerContainerName
|
||||
c.flags.sources["pg-docker-container-name"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgWalDir != nil && *c.flags.pgWalDir != "" {
|
||||
c.PgWalDir = *c.flags.pgWalDir
|
||||
c.flags.sources["pg-wal-dir"] = "command line args"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) logConfigSources() {
|
||||
log.Info("databasus-host", "value", c.DatabasusHost, "source", c.flags.sources["databasus-host"])
|
||||
log.Info("db-id", "value", c.DbID, "source", c.flags.sources["db-id"])
|
||||
log.Info("token", "value", maskSensitive(c.Token), "source", c.flags.sources["token"])
|
||||
log.Info("pg-host", "value", c.PgHost, "source", c.flags.sources["pg-host"])
|
||||
log.Info("pg-port", "value", c.PgPort, "source", c.flags.sources["pg-port"])
|
||||
log.Info("pg-user", "value", c.PgUser, "source", c.flags.sources["pg-user"])
|
||||
log.Info("pg-password", "value", maskSensitive(c.PgPassword), "source", c.flags.sources["pg-password"])
|
||||
log.Info("pg-type", "value", c.PgType, "source", c.flags.sources["pg-type"])
|
||||
log.Info("pg-host-bin-dir", "value", c.PgHostBinDir, "source", c.flags.sources["pg-host-bin-dir"])
|
||||
log.Info(
|
||||
"pg-docker-container-name",
|
||||
"value",
|
||||
c.PgDockerContainerName,
|
||||
"source",
|
||||
c.flags.sources["pg-docker-container-name"],
|
||||
)
|
||||
log.Info("pg-wal-dir", "value", c.PgWalDir, "source", c.flags.sources["pg-wal-dir"])
|
||||
log.Info(
|
||||
"delete-wal-after-upload",
|
||||
"value",
|
||||
fmt.Sprintf("%v", *c.IsDeleteWalAfterUpload),
|
||||
"source",
|
||||
c.flags.sources["delete-wal-after-upload"],
|
||||
)
|
||||
}
|
||||
|
||||
func maskSensitive(value string) string {
|
||||
if value == "" {
|
||||
return "(not set)"
|
||||
}
|
||||
|
||||
visibleLen := max(len(value)/4, 1)
|
||||
|
||||
return value[:visibleLen] + "***"
|
||||
}
|
||||
301
agent/internal/config/config_test.go
Normal file
301
agent/internal/config/config_test.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_LoadFromJSONAndArgs_ValuesLoadedFromJSON(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
writeConfigJSON(t, dir, Config{
|
||||
DatabasusHost: "http://json-host:4005",
|
||||
DbID: "json-db-id",
|
||||
Token: "json-token",
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{})
|
||||
|
||||
assert.Equal(t, "http://json-host:4005", cfg.DatabasusHost)
|
||||
assert.Equal(t, "json-db-id", cfg.DbID)
|
||||
assert.Equal(t, "json-token", cfg.Token)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_ValuesLoadedFromArgs_WhenNoJSON(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--databasus-host", "http://arg-host:4005",
|
||||
"--db-id", "arg-db-id",
|
||||
"--token", "arg-token",
|
||||
})
|
||||
|
||||
assert.Equal(t, "http://arg-host:4005", cfg.DatabasusHost)
|
||||
assert.Equal(t, "arg-db-id", cfg.DbID)
|
||||
assert.Equal(t, "arg-token", cfg.Token)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_ArgsOverrideJSON(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
writeConfigJSON(t, dir, Config{
|
||||
DatabasusHost: "http://json-host:4005",
|
||||
DbID: "json-db-id",
|
||||
Token: "json-token",
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--databasus-host", "http://arg-host:9999",
|
||||
"--db-id", "arg-db-id-override",
|
||||
"--token", "arg-token-override",
|
||||
})
|
||||
|
||||
assert.Equal(t, "http://arg-host:9999", cfg.DatabasusHost)
|
||||
assert.Equal(t, "arg-db-id-override", cfg.DbID)
|
||||
assert.Equal(t, "arg-token-override", cfg.Token)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_PartialArgsOverrideJSON(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
writeConfigJSON(t, dir, Config{
|
||||
DatabasusHost: "http://json-host:4005",
|
||||
DbID: "json-db-id",
|
||||
Token: "json-token",
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--databasus-host", "http://arg-host-only:4005",
|
||||
})
|
||||
|
||||
assert.Equal(t, "http://arg-host-only:4005", cfg.DatabasusHost)
|
||||
assert.Equal(t, "json-db-id", cfg.DbID)
|
||||
assert.Equal(t, "json-token", cfg.Token)
|
||||
}
|
||||
|
||||
func Test_SaveToJSON_ConfigSavedCorrectly(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
deleteWal := true
|
||||
cfg := &Config{
|
||||
DatabasusHost: "http://save-host:4005",
|
||||
DbID: "save-db-id",
|
||||
Token: "save-token",
|
||||
IsDeleteWalAfterUpload: &deleteWal,
|
||||
}
|
||||
|
||||
err := cfg.SaveToJSON()
|
||||
require.NoError(t, err)
|
||||
|
||||
saved := readConfigJSON(t)
|
||||
|
||||
assert.Equal(t, "http://save-host:4005", saved.DatabasusHost)
|
||||
assert.Equal(t, "save-db-id", saved.DbID)
|
||||
assert.Equal(t, "save-token", saved.Token)
|
||||
}
|
||||
|
||||
func Test_SaveToJSON_AfterArgsOverrideJSON_SavedFileContainsMergedValues(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
writeConfigJSON(t, dir, Config{
|
||||
DatabasusHost: "http://json-host:4005",
|
||||
DbID: "json-db-id",
|
||||
Token: "json-token",
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--databasus-host", "http://override-host:9999",
|
||||
})
|
||||
|
||||
err := cfg.SaveToJSON()
|
||||
require.NoError(t, err)
|
||||
|
||||
saved := readConfigJSON(t)
|
||||
|
||||
assert.Equal(t, "http://override-host:9999", saved.DatabasusHost)
|
||||
assert.Equal(t, "json-db-id", saved.DbID)
|
||||
assert.Equal(t, "json-token", saved.Token)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_PgFieldsLoadedFromJSON(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
deleteWal := false
|
||||
writeConfigJSON(t, dir, Config{
|
||||
DatabasusHost: "http://json-host:4005",
|
||||
DbID: "json-db-id",
|
||||
Token: "json-token",
|
||||
PgHost: "pg-json-host",
|
||||
PgPort: 5433,
|
||||
PgUser: "pg-json-user",
|
||||
PgPassword: "pg-json-pass",
|
||||
PgType: "docker",
|
||||
PgHostBinDir: "/usr/bin",
|
||||
PgDockerContainerName: "pg-container",
|
||||
PgWalDir: "/opt/wal",
|
||||
IsDeleteWalAfterUpload: &deleteWal,
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{})
|
||||
|
||||
assert.Equal(t, "pg-json-host", cfg.PgHost)
|
||||
assert.Equal(t, 5433, cfg.PgPort)
|
||||
assert.Equal(t, "pg-json-user", cfg.PgUser)
|
||||
assert.Equal(t, "pg-json-pass", cfg.PgPassword)
|
||||
assert.Equal(t, "docker", cfg.PgType)
|
||||
assert.Equal(t, "/usr/bin", cfg.PgHostBinDir)
|
||||
assert.Equal(t, "pg-container", cfg.PgDockerContainerName)
|
||||
assert.Equal(t, "/opt/wal", cfg.PgWalDir)
|
||||
assert.Equal(t, false, *cfg.IsDeleteWalAfterUpload)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_PgFieldsLoadedFromArgs(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--pg-host", "arg-pg-host",
|
||||
"--pg-port", "5433",
|
||||
"--pg-user", "arg-pg-user",
|
||||
"--pg-password", "arg-pg-pass",
|
||||
"--pg-type", "docker",
|
||||
"--pg-host-bin-dir", "/custom/bin",
|
||||
"--pg-docker-container-name", "my-pg",
|
||||
"--pg-wal-dir", "/var/wal",
|
||||
})
|
||||
|
||||
assert.Equal(t, "arg-pg-host", cfg.PgHost)
|
||||
assert.Equal(t, 5433, cfg.PgPort)
|
||||
assert.Equal(t, "arg-pg-user", cfg.PgUser)
|
||||
assert.Equal(t, "arg-pg-pass", cfg.PgPassword)
|
||||
assert.Equal(t, "docker", cfg.PgType)
|
||||
assert.Equal(t, "/custom/bin", cfg.PgHostBinDir)
|
||||
assert.Equal(t, "my-pg", cfg.PgDockerContainerName)
|
||||
assert.Equal(t, "/var/wal", cfg.PgWalDir)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_PgArgsOverrideJSON(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
writeConfigJSON(t, dir, Config{
|
||||
PgHost: "json-host",
|
||||
PgPort: 5432,
|
||||
PgUser: "json-user",
|
||||
PgType: "host",
|
||||
PgWalDir: "/json/wal",
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--pg-host", "arg-host",
|
||||
"--pg-port", "5433",
|
||||
"--pg-user", "arg-user",
|
||||
"--pg-type", "docker",
|
||||
"--pg-docker-container-name", "my-container",
|
||||
"--pg-wal-dir", "/arg/wal",
|
||||
})
|
||||
|
||||
assert.Equal(t, "arg-host", cfg.PgHost)
|
||||
assert.Equal(t, 5433, cfg.PgPort)
|
||||
assert.Equal(t, "arg-user", cfg.PgUser)
|
||||
assert.Equal(t, "docker", cfg.PgType)
|
||||
assert.Equal(t, "my-container", cfg.PgDockerContainerName)
|
||||
assert.Equal(t, "/arg/wal", cfg.PgWalDir)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_DefaultsApplied_WhenNoJSONAndNoArgs(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{})
|
||||
|
||||
assert.Equal(t, 5432, cfg.PgPort)
|
||||
assert.Equal(t, "host", cfg.PgType)
|
||||
require.NotNil(t, cfg.IsDeleteWalAfterUpload)
|
||||
assert.Equal(t, true, *cfg.IsDeleteWalAfterUpload)
|
||||
}
|
||||
|
||||
func Test_SaveToJSON_PgFieldsSavedCorrectly(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
deleteWal := false
|
||||
cfg := &Config{
|
||||
DatabasusHost: "http://host:4005",
|
||||
DbID: "db-id",
|
||||
Token: "token",
|
||||
PgHost: "pg-host",
|
||||
PgPort: 5433,
|
||||
PgUser: "pg-user",
|
||||
PgPassword: "pg-pass",
|
||||
PgType: "docker",
|
||||
PgHostBinDir: "/usr/bin",
|
||||
PgDockerContainerName: "pg-container",
|
||||
PgWalDir: "/opt/wal",
|
||||
IsDeleteWalAfterUpload: &deleteWal,
|
||||
}
|
||||
|
||||
err := cfg.SaveToJSON()
|
||||
require.NoError(t, err)
|
||||
|
||||
saved := readConfigJSON(t)
|
||||
|
||||
assert.Equal(t, "pg-host", saved.PgHost)
|
||||
assert.Equal(t, 5433, saved.PgPort)
|
||||
assert.Equal(t, "pg-user", saved.PgUser)
|
||||
assert.Equal(t, "pg-pass", saved.PgPassword)
|
||||
assert.Equal(t, "docker", saved.PgType)
|
||||
assert.Equal(t, "/usr/bin", saved.PgHostBinDir)
|
||||
assert.Equal(t, "pg-container", saved.PgDockerContainerName)
|
||||
assert.Equal(t, "/opt/wal", saved.PgWalDir)
|
||||
require.NotNil(t, saved.IsDeleteWalAfterUpload)
|
||||
assert.Equal(t, false, *saved.IsDeleteWalAfterUpload)
|
||||
}
|
||||
|
||||
func setupTempDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
origDir, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
|
||||
dir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(dir))
|
||||
|
||||
t.Cleanup(func() { os.Chdir(origDir) })
|
||||
|
||||
return dir
|
||||
}
|
||||
|
||||
func writeConfigJSON(t *testing.T, dir string, cfg Config) {
|
||||
t.Helper()
|
||||
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, os.WriteFile(dir+"/"+configFileName, data, 0o644))
|
||||
}
|
||||
|
||||
func readConfigJSON(t *testing.T) Config {
|
||||
t.Helper()
|
||||
|
||||
data, err := os.ReadFile(configFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
var cfg Config
|
||||
require.NoError(t, json.Unmarshal(data, &cfg))
|
||||
|
||||
return cfg
|
||||
}
|
||||
17
agent/internal/config/dto.go
Normal file
17
agent/internal/config/dto.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package config
|
||||
|
||||
type parsedFlags struct {
|
||||
databasusHost *string
|
||||
dbID *string
|
||||
token *string
|
||||
pgHost *string
|
||||
pgPort *int
|
||||
pgUser *string
|
||||
pgPassword *string
|
||||
pgType *string
|
||||
pgHostBinDir *string
|
||||
pgDockerContainerName *string
|
||||
pgWalDir *string
|
||||
|
||||
sources map[string]string
|
||||
}
|
||||
376
agent/internal/features/api/api.go
Normal file
376
agent/internal/features/api/api.go
Normal file
@@ -0,0 +1,376 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
chainValidPath = "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup"
|
||||
nextBackupTimePath = "/api/v1/backups/postgres/wal/next-full-backup-time"
|
||||
walUploadPath = "/api/v1/backups/postgres/wal/upload/wal"
|
||||
fullStartPath = "/api/v1/backups/postgres/wal/upload/full-start"
|
||||
fullCompletePath = "/api/v1/backups/postgres/wal/upload/full-complete"
|
||||
reportErrorPath = "/api/v1/backups/postgres/wal/error"
|
||||
restorePlanPath = "/api/v1/backups/postgres/wal/restore/plan"
|
||||
restoreDownloadPath = "/api/v1/backups/postgres/wal/restore/download"
|
||||
versionPath = "/api/v1/system/version"
|
||||
agentBinaryPath = "/api/v1/system/agent"
|
||||
|
||||
apiCallTimeout = 30 * time.Second
|
||||
maxRetryAttempts = 3
|
||||
retryBaseDelay = 1 * time.Second
|
||||
)
|
||||
|
||||
// For stream uploads (basebackup and WAL segments) the standard resty client is not used,
|
||||
// because it buffers the entire body in memory before sending.
|
||||
type Client struct {
|
||||
json *resty.Client
|
||||
streamHTTP *http.Client
|
||||
host string
|
||||
token string
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
func NewClient(host, token string, log *slog.Logger) *Client {
|
||||
setAuth := func(_ *resty.Client, req *resty.Request) error {
|
||||
if token != "" {
|
||||
req.SetHeader("Authorization", token)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
jsonClient := resty.New().
|
||||
SetTimeout(apiCallTimeout).
|
||||
SetRetryCount(maxRetryAttempts - 1).
|
||||
SetRetryWaitTime(retryBaseDelay).
|
||||
SetRetryMaxWaitTime(4 * retryBaseDelay).
|
||||
AddRetryCondition(func(resp *resty.Response, err error) bool {
|
||||
return err != nil || resp.StatusCode() >= 500
|
||||
}).
|
||||
OnBeforeRequest(setAuth)
|
||||
|
||||
return &Client{
|
||||
json: jsonClient,
|
||||
streamHTTP: &http.Client{},
|
||||
host: host,
|
||||
token: token,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) CheckWalChainValidity(ctx context.Context) (*WalChainValidityResponse, error) {
|
||||
var resp WalChainValidityResponse
|
||||
|
||||
httpResp, err := c.json.R().
|
||||
SetContext(ctx).
|
||||
SetResult(&resp).
|
||||
Get(c.buildURL(chainValidPath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.checkResponse(httpResp, "check WAL chain validity"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetNextFullBackupTime(ctx context.Context) (*NextFullBackupTimeResponse, error) {
|
||||
var resp NextFullBackupTimeResponse
|
||||
|
||||
httpResp, err := c.json.R().
|
||||
SetContext(ctx).
|
||||
SetResult(&resp).
|
||||
Get(c.buildURL(nextBackupTimePath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.checkResponse(httpResp, "get next full backup time"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (c *Client) ReportBackupError(ctx context.Context, errMsg string) error {
|
||||
httpResp, err := c.json.R().
|
||||
SetContext(ctx).
|
||||
SetBody(reportErrorRequest{Error: errMsg}).
|
||||
Post(c.buildURL(reportErrorPath))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.checkResponse(httpResp, "report backup error")
|
||||
}
|
||||
|
||||
func (c *Client) UploadBasebackup(
|
||||
ctx context.Context,
|
||||
body io.Reader,
|
||||
) (*UploadBasebackupResponse, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL(fullStartPath), body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create upload request: %w", err)
|
||||
}
|
||||
|
||||
c.setStreamHeaders(req)
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
|
||||
resp, err := c.streamHTTP.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upload request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
return nil, fmt.Errorf("upload failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result UploadBasebackupResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("decode upload response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *Client) FinalizeBasebackup(
|
||||
ctx context.Context,
|
||||
backupID string,
|
||||
startSegment string,
|
||||
stopSegment string,
|
||||
) error {
|
||||
resp, err := c.json.R().
|
||||
SetContext(ctx).
|
||||
SetBody(finalizeBasebackupRequest{
|
||||
BackupID: backupID,
|
||||
StartSegment: startSegment,
|
||||
StopSegment: stopSegment,
|
||||
}).
|
||||
Post(c.buildURL(fullCompletePath))
|
||||
if err != nil {
|
||||
return fmt.Errorf("finalize request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode() != http.StatusOK {
|
||||
return fmt.Errorf("finalize failed with status %d: %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) FinalizeBasebackupWithError(
|
||||
ctx context.Context,
|
||||
backupID string,
|
||||
errMsg string,
|
||||
) error {
|
||||
resp, err := c.json.R().
|
||||
SetContext(ctx).
|
||||
SetBody(finalizeBasebackupRequest{
|
||||
BackupID: backupID,
|
||||
Error: &errMsg,
|
||||
}).
|
||||
Post(c.buildURL(fullCompletePath))
|
||||
if err != nil {
|
||||
return fmt.Errorf("finalize-with-error request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode() != http.StatusOK {
|
||||
return fmt.Errorf("finalize-with-error failed with status %d: %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) UploadWalSegment(
|
||||
ctx context.Context,
|
||||
segmentName string,
|
||||
body io.Reader,
|
||||
) (*UploadWalSegmentResult, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL(walUploadPath), body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create WAL upload request: %w", err)
|
||||
}
|
||||
|
||||
c.setStreamHeaders(req)
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
req.Header.Set("X-Wal-Segment-Name", segmentName)
|
||||
|
||||
resp, err := c.streamHTTP.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upload request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusNoContent:
|
||||
return &UploadWalSegmentResult{IsGapDetected: false}, nil
|
||||
|
||||
case http.StatusConflict:
|
||||
var errResp uploadErrorResponse
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&errResp); err != nil {
|
||||
return &UploadWalSegmentResult{IsGapDetected: true}, nil
|
||||
}
|
||||
|
||||
return &UploadWalSegmentResult{
|
||||
IsGapDetected: true,
|
||||
ExpectedSegmentName: errResp.ExpectedSegmentName,
|
||||
ReceivedSegmentName: errResp.ReceivedSegmentName,
|
||||
}, nil
|
||||
|
||||
default:
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
return nil, fmt.Errorf("upload failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) GetRestorePlan(
|
||||
ctx context.Context,
|
||||
backupID string,
|
||||
) (*GetRestorePlanResponse, *GetRestorePlanErrorResponse, error) {
|
||||
request := c.json.R().SetContext(ctx)
|
||||
|
||||
if backupID != "" {
|
||||
request.SetQueryParam("backupId", backupID)
|
||||
}
|
||||
|
||||
httpResp, err := request.Get(c.buildURL(restorePlanPath))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get restore plan: %w", err)
|
||||
}
|
||||
|
||||
switch httpResp.StatusCode() {
|
||||
case http.StatusOK:
|
||||
var response GetRestorePlanResponse
|
||||
if err := json.Unmarshal(httpResp.Body(), &response); err != nil {
|
||||
return nil, nil, fmt.Errorf("decode restore plan response: %w", err)
|
||||
}
|
||||
|
||||
return &response, nil, nil
|
||||
|
||||
case http.StatusBadRequest:
|
||||
var errorResponse GetRestorePlanErrorResponse
|
||||
if err := json.Unmarshal(httpResp.Body(), &errorResponse); err != nil {
|
||||
return nil, nil, fmt.Errorf("decode restore plan error: %w", err)
|
||||
}
|
||||
|
||||
return nil, &errorResponse, nil
|
||||
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("get restore plan: server returned status %d: %s",
|
||||
httpResp.StatusCode(), httpResp.String())
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) DownloadBackupFile(
|
||||
ctx context.Context,
|
||||
backupID string,
|
||||
) (io.ReadCloser, error) {
|
||||
requestURL := c.buildURL(restoreDownloadPath) + "?" + url.Values{"backupId": {backupID}}.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create download request: %w", err)
|
||||
}
|
||||
|
||||
c.setStreamHeaders(req)
|
||||
|
||||
resp, err := c.streamHTTP.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("download backup file: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
|
||||
return nil, fmt.Errorf("download backup file: server returned status %d: %s",
|
||||
resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
func (c *Client) FetchServerVersion(ctx context.Context) (string, error) {
|
||||
var ver versionResponse
|
||||
|
||||
httpResp, err := c.json.R().
|
||||
SetContext(ctx).
|
||||
SetResult(&ver).
|
||||
Get(c.buildURL(versionPath))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := c.checkResponse(httpResp, "fetch server version"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return ver.Version, nil
|
||||
}
|
||||
|
||||
func (c *Client) DownloadAgentBinary(ctx context.Context, arch, destPath string) error {
|
||||
requestURL := c.buildURL(agentBinaryPath) + "?" + url.Values{"arch": {arch}}.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create agent download request: %w", err)
|
||||
}
|
||||
|
||||
c.setStreamHeaders(req)
|
||||
|
||||
resp, err := c.streamHTTP.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("server returned %d for agent download", resp.StatusCode)
|
||||
}
|
||||
|
||||
file, err := os.Create(destPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
_, err = io.Copy(file, resp.Body)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) buildURL(path string) string {
|
||||
return c.host + path
|
||||
}
|
||||
|
||||
func (c *Client) checkResponse(resp *resty.Response, method string) error {
|
||||
if resp.StatusCode() >= 400 {
|
||||
return fmt.Errorf("%s: server returned status %d: %s", method, resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) setStreamHeaders(req *http.Request) {
|
||||
if c.token != "" {
|
||||
req.Header.Set("Authorization", c.token)
|
||||
}
|
||||
}
|
||||
72
agent/internal/features/api/dto.go
Normal file
72
agent/internal/features/api/dto.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package api
|
||||
|
||||
import "time"
|
||||
|
||||
type WalChainValidityResponse struct {
|
||||
IsValid bool `json:"isValid"`
|
||||
Error string `json:"error,omitempty"`
|
||||
LastContiguousSegment string `json:"lastContiguousSegment,omitempty"`
|
||||
}
|
||||
|
||||
type NextFullBackupTimeResponse struct {
|
||||
NextFullBackupTime *time.Time `json:"nextFullBackupTime"`
|
||||
}
|
||||
|
||||
type UploadWalSegmentResult struct {
|
||||
IsGapDetected bool
|
||||
ExpectedSegmentName string
|
||||
ReceivedSegmentName string
|
||||
}
|
||||
|
||||
type reportErrorRequest struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type versionResponse struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
type UploadBasebackupResponse struct {
|
||||
BackupID string `json:"backupId"`
|
||||
}
|
||||
|
||||
type finalizeBasebackupRequest struct {
|
||||
BackupID string `json:"backupId"`
|
||||
StartSegment string `json:"startSegment"`
|
||||
StopSegment string `json:"stopSegment"`
|
||||
Error *string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type uploadErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
ExpectedSegmentName string `json:"expectedSegmentName"`
|
||||
ReceivedSegmentName string `json:"receivedSegmentName"`
|
||||
}
|
||||
|
||||
type RestorePlanFullBackup struct {
|
||||
BackupID string `json:"id"`
|
||||
FullBackupWalStartSegment string `json:"fullBackupWalStartSegment"`
|
||||
FullBackupWalStopSegment string `json:"fullBackupWalStopSegment"`
|
||||
PgVersion string `json:"pgVersion"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
SizeBytes int64 `json:"sizeBytes"`
|
||||
}
|
||||
|
||||
type RestorePlanWalSegment struct {
|
||||
BackupID string `json:"backupId"`
|
||||
SegmentName string `json:"segmentName"`
|
||||
SizeBytes int64 `json:"sizeBytes"`
|
||||
}
|
||||
|
||||
type GetRestorePlanResponse struct {
|
||||
FullBackup RestorePlanFullBackup `json:"fullBackup"`
|
||||
WalSegments []RestorePlanWalSegment `json:"walSegments"`
|
||||
TotalSizeBytes int64 `json:"totalSizeBytes"`
|
||||
LatestAvailableSegment string `json:"latestAvailableSegment"`
|
||||
}
|
||||
|
||||
type GetRestorePlanErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Message string `json:"message"`
|
||||
LastContiguousSegment string `json:"lastContiguousSegment,omitempty"`
|
||||
}
|
||||
60
agent/internal/features/api/idle_timeout_reader.go
Normal file
60
agent/internal/features/api/idle_timeout_reader.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IdleTimeoutReader wraps an io.Reader and cancels the associated context
|
||||
// if no bytes are successfully read within the specified timeout duration.
|
||||
// This detects stalled uploads where the network or source stops transmitting data.
|
||||
//
|
||||
// When the idle timeout fires, the reader is also closed (if it implements io.Closer)
|
||||
// to unblock any goroutine blocked on the underlying Read.
|
||||
type IdleTimeoutReader struct {
|
||||
reader io.Reader
|
||||
timeout time.Duration
|
||||
cancel context.CancelCauseFunc
|
||||
timer *time.Timer
|
||||
}
|
||||
|
||||
// NewIdleTimeoutReader creates a reader that cancels the context via cancel
|
||||
// if Read does not return any bytes for the given timeout duration.
|
||||
func NewIdleTimeoutReader(reader io.Reader, timeout time.Duration, cancel context.CancelCauseFunc) *IdleTimeoutReader {
|
||||
r := &IdleTimeoutReader{
|
||||
reader: reader,
|
||||
timeout: timeout,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
r.timer = time.AfterFunc(timeout, func() {
|
||||
cancel(fmt.Errorf("upload idle timeout: no bytes transmitted for %v", timeout))
|
||||
|
||||
if closer, ok := reader.(io.Closer); ok {
|
||||
_ = closer.Close()
|
||||
}
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *IdleTimeoutReader) Read(p []byte) (int, error) {
|
||||
n, err := r.reader.Read(p)
|
||||
|
||||
if n > 0 {
|
||||
r.timer.Reset(r.timeout)
|
||||
}
|
||||
|
||||
if err != nil && err != io.EOF {
|
||||
r.Stop()
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Stop cancels the idle timer. Must be called when the reader is no longer needed.
|
||||
func (r *IdleTimeoutReader) Stop() {
|
||||
r.timer.Stop()
|
||||
}
|
||||
112
agent/internal/features/api/idle_timeout_reader_test.go
Normal file
112
agent/internal/features/api/idle_timeout_reader_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_ReadThroughIdleTimeoutReader_WhenBytesFlowContinuously_DoesNotCancelContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
defer cancel(nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
idleReader := NewIdleTimeoutReader(pr, 200*time.Millisecond, cancel)
|
||||
defer idleReader.Stop()
|
||||
|
||||
go func() {
|
||||
for range 5 {
|
||||
_, _ = pw.Write([]byte("data"))
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
_ = pw.Close()
|
||||
}()
|
||||
|
||||
data, err := io.ReadAll(idleReader)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "datadatadatadatadata", string(data))
|
||||
assert.NoError(t, ctx.Err(), "context should not be cancelled when bytes flow continuously")
|
||||
}
|
||||
|
||||
func Test_ReadThroughIdleTimeoutReader_WhenNoBytesTransmitted_CancelsContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
defer cancel(nil)
|
||||
|
||||
pr, _ := io.Pipe()
|
||||
|
||||
idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel)
|
||||
defer idleReader.Stop()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
assert.Error(t, ctx.Err(), "context should be cancelled when no bytes are transmitted")
|
||||
assert.Contains(t, context.Cause(ctx).Error(), "upload idle timeout")
|
||||
}
|
||||
|
||||
func Test_ReadThroughIdleTimeoutReader_WhenBytesStopMidStream_CancelsContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
defer cancel(nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel)
|
||||
defer idleReader.Stop()
|
||||
|
||||
go func() {
|
||||
_, _ = pw.Write([]byte("initial"))
|
||||
// Stop writing — simulate stalled source
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
n, _ := idleReader.Read(buf)
|
||||
assert.Equal(t, "initial", string(buf[:n]))
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
assert.Error(t, ctx.Err(), "context should be cancelled when bytes stop mid-stream")
|
||||
assert.Contains(t, context.Cause(ctx).Error(), "upload idle timeout")
|
||||
}
|
||||
|
||||
func Test_StopIdleTimeoutReader_WhenCalledBeforeTimeout_DoesNotCancelContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
defer cancel(nil)
|
||||
|
||||
pr, _ := io.Pipe()
|
||||
|
||||
idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel)
|
||||
idleReader.Stop()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
assert.NoError(t, ctx.Err(), "context should not be cancelled when reader is stopped before timeout")
|
||||
}
|
||||
|
||||
func Test_ReadThroughIdleTimeoutReader_WhenReaderReturnsError_PropagatesError(t *testing.T) {
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
defer cancel(nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
idleReader := NewIdleTimeoutReader(pr, 5*time.Second, cancel)
|
||||
defer idleReader.Stop()
|
||||
|
||||
expectedErr := fmt.Errorf("test read error")
|
||||
_ = pw.CloseWithError(expectedErr)
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
_, err := idleReader.Read(buf)
|
||||
|
||||
assert.ErrorIs(t, err, expectedErr)
|
||||
|
||||
// Timer should be stopped after error — context should not be cancelled
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
assert.NoError(t, ctx.Err(), "context should not be cancelled after reader error stops the timer")
|
||||
}
|
||||
316
agent/internal/features/full_backup/backuper.go
Normal file
316
agent/internal/features/full_backup/backuper.go
Normal file
@@ -0,0 +1,316 @@
|
||||
package full_backup
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
)
|
||||
|
||||
const (
|
||||
checkInterval = 30 * time.Second
|
||||
retryDelay = 1 * time.Minute
|
||||
uploadTimeout = 23 * time.Hour
|
||||
)
|
||||
|
||||
var uploadIdleTimeout = 5 * time.Minute
|
||||
|
||||
var retryDelayOverride *time.Duration
|
||||
|
||||
type CmdBuilder func(ctx context.Context) *exec.Cmd
|
||||
|
||||
// FullBackuper runs pg_basebackup when the WAL chain is broken or a scheduled backup is due.
|
||||
//
|
||||
// Every 30 seconds it checks two conditions via the Databasus API:
|
||||
// 1. WAL chain validity — if broken or no full backup exists, triggers an immediate basebackup.
|
||||
// 2. Scheduled backup time — if the next full backup time has passed, triggers a basebackup.
|
||||
//
|
||||
// Only one basebackup runs at a time (guarded by atomic bool).
|
||||
// On failure the error is reported to the server and the backup retries after 1 minute, indefinitely.
|
||||
// WAL segment uploads (handled by wal.Streamer) continue independently and are not paused.
|
||||
//
|
||||
// pg_basebackup runs as "pg_basebackup -Ft -D - -X fetch --verbose --checkpoint=fast".
|
||||
// Stdout (tar) is zstd-compressed and uploaded to the server.
|
||||
// Stderr is parsed for WAL start/stop segment names (LSN → segment arithmetic).
|
||||
type FullBackuper struct {
|
||||
cfg *config.Config
|
||||
apiClient *api.Client
|
||||
log *slog.Logger
|
||||
isRunning atomic.Bool
|
||||
cmdBuilder CmdBuilder
|
||||
}
|
||||
|
||||
func NewFullBackuper(cfg *config.Config, apiClient *api.Client, log *slog.Logger) *FullBackuper {
|
||||
backuper := &FullBackuper{
|
||||
cfg: cfg,
|
||||
apiClient: apiClient,
|
||||
log: log,
|
||||
}
|
||||
|
||||
backuper.cmdBuilder = backuper.defaultCmdBuilder
|
||||
|
||||
return backuper
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) Run(ctx context.Context) {
|
||||
backuper.log.Info("Full backuper started")
|
||||
|
||||
backuper.checkAndRunIfNeeded(ctx)
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
backuper.log.Info("Full backuper stopping")
|
||||
return
|
||||
case <-ticker.C:
|
||||
backuper.checkAndRunIfNeeded(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) checkAndRunIfNeeded(ctx context.Context) {
|
||||
if backuper.isRunning.Load() {
|
||||
backuper.log.Debug("Skipping check: basebackup already in progress")
|
||||
return
|
||||
}
|
||||
|
||||
chainResp, err := backuper.apiClient.CheckWalChainValidity(ctx)
|
||||
if err != nil {
|
||||
backuper.log.Error("Failed to check WAL chain validity", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !chainResp.IsValid {
|
||||
backuper.log.Info("WAL chain is invalid, triggering basebackup",
|
||||
"error", chainResp.Error,
|
||||
"lastContiguousSegment", chainResp.LastContiguousSegment,
|
||||
)
|
||||
|
||||
backuper.runBasebackupWithRetry(ctx)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
nextTimeResp, err := backuper.apiClient.GetNextFullBackupTime(ctx)
|
||||
if err != nil {
|
||||
backuper.log.Error("Failed to check next full backup time", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if nextTimeResp.NextFullBackupTime == nil || !nextTimeResp.NextFullBackupTime.After(time.Now().UTC()) {
|
||||
backuper.log.Info("Scheduled full backup is due, triggering basebackup")
|
||||
backuper.runBasebackupWithRetry(ctx)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backuper.log.Debug("No basebackup needed",
|
||||
"nextFullBackupTime", nextTimeResp.NextFullBackupTime,
|
||||
)
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) runBasebackupWithRetry(ctx context.Context) {
|
||||
if !backuper.isRunning.CompareAndSwap(false, true) {
|
||||
backuper.log.Debug("Skipping basebackup: already running")
|
||||
return
|
||||
}
|
||||
defer backuper.isRunning.Store(false)
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
backuper.log.Info("Starting pg_basebackup")
|
||||
|
||||
err := backuper.executeAndUploadBasebackup(ctx)
|
||||
if err == nil {
|
||||
backuper.log.Info("Basebackup completed successfully")
|
||||
return
|
||||
}
|
||||
|
||||
backuper.log.Error("Basebackup failed", "error", err)
|
||||
backuper.reportError(ctx, err.Error())
|
||||
|
||||
delay := retryDelay
|
||||
if retryDelayOverride != nil {
|
||||
delay = *retryDelayOverride
|
||||
}
|
||||
|
||||
backuper.log.Info("Retrying basebackup after delay", "delay", delay)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(delay):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) executeAndUploadBasebackup(ctx context.Context) error {
|
||||
cmd := backuper.cmdBuilder(ctx)
|
||||
|
||||
var stderrBuf bytes.Buffer
|
||||
cmd.Stderr = &stderrBuf
|
||||
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("start pg_basebackup: %w", err)
|
||||
}
|
||||
|
||||
// Phase 1: Stream compressed data via io.Pipe directly to the API.
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
defer func() { _ = pipeReader.Close() }()
|
||||
|
||||
go backuper.compressAndStream(pipeWriter, stdoutPipe)
|
||||
|
||||
uploadCtx, timeoutCancel := context.WithTimeout(ctx, uploadTimeout)
|
||||
defer timeoutCancel()
|
||||
|
||||
idleCtx, idleCancel := context.WithCancelCause(uploadCtx)
|
||||
defer idleCancel(nil)
|
||||
|
||||
idleReader := api.NewIdleTimeoutReader(pipeReader, uploadIdleTimeout, idleCancel)
|
||||
defer idleReader.Stop()
|
||||
|
||||
uploadResp, uploadErr := backuper.apiClient.UploadBasebackup(idleCtx, idleReader)
|
||||
|
||||
if uploadErr != nil && cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
|
||||
cmdErr := cmd.Wait()
|
||||
|
||||
if uploadErr != nil {
|
||||
if cause := context.Cause(idleCtx); cause != nil {
|
||||
uploadErr = cause
|
||||
}
|
||||
|
||||
stderrStr := stderrBuf.String()
|
||||
if stderrStr != "" {
|
||||
return fmt.Errorf("upload basebackup: %w (pg_basebackup stderr: %s)", uploadErr, stderrStr)
|
||||
}
|
||||
|
||||
return fmt.Errorf("upload basebackup: %w", uploadErr)
|
||||
}
|
||||
|
||||
if cmdErr != nil {
|
||||
errMsg := fmt.Sprintf("pg_basebackup exited with error: %v (stderr: %s)", cmdErr, stderrBuf.String())
|
||||
_ = backuper.apiClient.FinalizeBasebackupWithError(ctx, uploadResp.BackupID, errMsg)
|
||||
|
||||
return fmt.Errorf("%s", errMsg)
|
||||
}
|
||||
|
||||
// Phase 2: Parse stderr for WAL segments and finalize the backup.
|
||||
stderrStr := stderrBuf.String()
|
||||
backuper.log.Debug("pg_basebackup stderr", "stderr", stderrStr)
|
||||
|
||||
startSegment, stopSegment, err := ParseBasebackupStderr(stderrStr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("parse pg_basebackup stderr: %v", err)
|
||||
_ = backuper.apiClient.FinalizeBasebackupWithError(ctx, uploadResp.BackupID, errMsg)
|
||||
|
||||
return fmt.Errorf("parse pg_basebackup stderr: %w", err)
|
||||
}
|
||||
|
||||
backuper.log.Info("Basebackup WAL segments parsed",
|
||||
"startSegment", startSegment,
|
||||
"stopSegment", stopSegment,
|
||||
"backupId", uploadResp.BackupID,
|
||||
)
|
||||
|
||||
if err := backuper.apiClient.FinalizeBasebackup(ctx, uploadResp.BackupID, startSegment, stopSegment); err != nil {
|
||||
return fmt.Errorf("finalize basebackup: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) compressAndStream(pipeWriter *io.PipeWriter, reader io.Reader) {
|
||||
encoder, err := zstd.NewWriter(pipeWriter,
|
||||
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(5)),
|
||||
zstd.WithEncoderCRC(true),
|
||||
)
|
||||
if err != nil {
|
||||
_ = pipeWriter.CloseWithError(fmt.Errorf("create zstd encoder: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := io.Copy(encoder, reader); err != nil {
|
||||
_ = encoder.Close()
|
||||
_ = pipeWriter.CloseWithError(fmt.Errorf("compress: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := encoder.Close(); err != nil {
|
||||
_ = pipeWriter.CloseWithError(fmt.Errorf("close encoder: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
_ = pipeWriter.Close()
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) reportError(ctx context.Context, errMsg string) {
|
||||
if err := backuper.apiClient.ReportBackupError(ctx, errMsg); err != nil {
|
||||
backuper.log.Error("Failed to report error to server", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) defaultCmdBuilder(ctx context.Context) *exec.Cmd {
|
||||
switch backuper.cfg.PgType {
|
||||
case "docker":
|
||||
return backuper.buildDockerCmd(ctx)
|
||||
default:
|
||||
return backuper.buildHostCmd(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) buildHostCmd(ctx context.Context) *exec.Cmd {
|
||||
binary := "pg_basebackup"
|
||||
if backuper.cfg.PgHostBinDir != "" {
|
||||
binary = filepath.Join(backuper.cfg.PgHostBinDir, "pg_basebackup")
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, binary,
|
||||
"-Ft", "-D", "-", "-X", "fetch", "--verbose", "--checkpoint=fast",
|
||||
"-h", backuper.cfg.PgHost,
|
||||
"-p", fmt.Sprintf("%d", backuper.cfg.PgPort),
|
||||
"-U", backuper.cfg.PgUser,
|
||||
)
|
||||
|
||||
cmd.Env = append(os.Environ(), "PGPASSWORD="+backuper.cfg.PgPassword)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (backuper *FullBackuper) buildDockerCmd(ctx context.Context) *exec.Cmd {
|
||||
cmd := exec.CommandContext(ctx, "docker", "exec",
|
||||
"-e", "PGPASSWORD="+backuper.cfg.PgPassword,
|
||||
"-i", backuper.cfg.PgDockerContainerName,
|
||||
"pg_basebackup",
|
||||
"-Ft", "-D", "-", "-X", "fetch", "--verbose", "--checkpoint=fast",
|
||||
"-h", "localhost",
|
||||
"-p", "5432",
|
||||
"-U", backuper.cfg.PgUser,
|
||||
)
|
||||
|
||||
return cmd
|
||||
}
|
||||
735
agent/internal/features/full_backup/backuper_test.go
Normal file
735
agent/internal/features/full_backup/backuper_test.go
Normal file
@@ -0,0 +1,735 @@
|
||||
package full_backup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
testChainValidPath = "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup"
|
||||
testNextBackupTimePath = "/api/v1/backups/postgres/wal/next-full-backup-time"
|
||||
testFullStartPath = "/api/v1/backups/postgres/wal/upload/full-start"
|
||||
testFullCompletePath = "/api/v1/backups/postgres/wal/upload/full-complete"
|
||||
testReportErrorPath = "/api/v1/backups/postgres/wal/error"
|
||||
|
||||
testBackupID = "test-backup-id-1234"
|
||||
)
|
||||
|
||||
func Test_RunFullBackup_WhenChainBroken_BasebackupTriggered(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var uploadReceived bool
|
||||
var uploadHeaders http.Header
|
||||
var finalizeReceived bool
|
||||
var finalizeBody map[string]any
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "wal_chain_broken",
|
||||
LastContiguousSegment: "000000010000000100000011",
|
||||
})
|
||||
case testFullStartPath:
|
||||
mu.Lock()
|
||||
uploadReceived = true
|
||||
uploadHeaders = r.Header.Clone()
|
||||
mu.Unlock()
|
||||
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
mu.Lock()
|
||||
finalizeReceived = true
|
||||
_ = json.NewDecoder(r.Body).Decode(&finalizeBody)
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "test-backup-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return finalizeReceived
|
||||
}, 5*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.True(t, uploadReceived)
|
||||
assert.Equal(t, "application/octet-stream", uploadHeaders.Get("Content-Type"))
|
||||
assert.Equal(t, "test-token", uploadHeaders.Get("Authorization"))
|
||||
|
||||
assert.True(t, finalizeReceived)
|
||||
assert.Equal(t, testBackupID, finalizeBody["backupId"])
|
||||
assert.Equal(t, "000000010000000000000002", finalizeBody["startSegment"])
|
||||
assert.Equal(t, "000000010000000000000002", finalizeBody["stopSegment"])
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenScheduledBackupDue_BasebackupTriggered(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var finalizeReceived bool
|
||||
|
||||
pastTime := time.Now().UTC().Add(-1 * time.Hour)
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{IsValid: true})
|
||||
case testNextBackupTimePath:
|
||||
writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: &pastTime})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
mu.Lock()
|
||||
finalizeReceived = true
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "scheduled-backup-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return finalizeReceived
|
||||
}, 5*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.True(t, finalizeReceived)
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenNoFullBackupExists_ImmediateBasebackupTriggered(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var finalizeReceived bool
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
mu.Lock()
|
||||
finalizeReceived = true
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "first-backup-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return finalizeReceived
|
||||
}, 5*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.True(t, finalizeReceived)
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenUploadFails_RetriesAfterDelay(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var uploadAttempts int
|
||||
var errorReported bool
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
|
||||
mu.Lock()
|
||||
uploadAttempts++
|
||||
attempt := uploadAttempts
|
||||
mu.Unlock()
|
||||
|
||||
if attempt == 1 {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(`{"error":"storage unavailable"}`))
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
case testReportErrorPath:
|
||||
mu.Lock()
|
||||
errorReported = true
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "retry-backup-data", validStderr())
|
||||
|
||||
origRetryDelay := retryDelay
|
||||
setRetryDelay(100 * time.Millisecond)
|
||||
defer setRetryDelay(origRetryDelay)
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return uploadAttempts >= 2
|
||||
}, 10*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.GreaterOrEqual(t, uploadAttempts, 2)
|
||||
assert.True(t, errorReported)
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenAlreadyRunning_SkipsExecution(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var uploadCount int
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
|
||||
mu.Lock()
|
||||
uploadCount++
|
||||
mu.Unlock()
|
||||
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
|
||||
|
||||
fb.isRunning.Store(true)
|
||||
|
||||
fb.checkAndRunIfNeeded(t.Context())
|
||||
|
||||
mu.Lock()
|
||||
count := uploadCount
|
||||
mu.Unlock()
|
||||
|
||||
assert.Equal(t, 0, count, "should not trigger backup when already running")
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenContextCancelled_StopsCleanly(t *testing.T) {
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
case testFullCompletePath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
case testReportErrorPath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
|
||||
|
||||
origRetryDelay := retryDelay
|
||||
setRetryDelay(5 * time.Second)
|
||||
defer setRetryDelay(origRetryDelay)
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
fb.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Run should have stopped after context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenChainValidAndNotScheduled_NoBasebackupTriggered(t *testing.T) {
|
||||
var uploadReceived atomic.Bool
|
||||
|
||||
futureTime := time.Now().UTC().Add(24 * time.Hour)
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{IsValid: true})
|
||||
case testNextBackupTimePath:
|
||||
writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: &futureTime})
|
||||
case testFullStartPath:
|
||||
uploadReceived.Store(true)
|
||||
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
assert.False(t, uploadReceived.Load(), "should not trigger backup when chain valid and not scheduled")
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenStderrParsingFails_FinalizesWithErrorAndRetries(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var errorReported bool
|
||||
var finalizeWithErrorReceived bool
|
||||
var finalizeBody map[string]any
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
mu.Lock()
|
||||
finalizeWithErrorReceived = true
|
||||
_ = json.NewDecoder(r.Body).Decode(&finalizeBody)
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
case testReportErrorPath:
|
||||
mu.Lock()
|
||||
errorReported = true
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", "pg_basebackup: unexpected output with no LSN info")
|
||||
|
||||
origRetryDelay := retryDelay
|
||||
setRetryDelay(100 * time.Millisecond)
|
||||
defer setRetryDelay(origRetryDelay)
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return errorReported
|
||||
}, 2*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.True(t, errorReported)
|
||||
assert.True(t, finalizeWithErrorReceived, "should finalize with error when stderr parsing fails")
|
||||
assert.Equal(t, testBackupID, finalizeBody["backupId"])
|
||||
assert.NotNil(t, finalizeBody["error"], "finalize should include error message")
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenNextBackupTimeNull_BasebackupTriggered(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var finalizeReceived bool
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{IsValid: true})
|
||||
case testNextBackupTimePath:
|
||||
writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: nil})
|
||||
case testFullStartPath:
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
mu.Lock()
|
||||
finalizeReceived = true
|
||||
mu.Unlock()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "first-run-data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return finalizeReceived
|
||||
}, 5*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.True(t, finalizeReceived)
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenChainValidityReturns401_NoBasebackupTriggered(t *testing.T) {
|
||||
var uploadReceived atomic.Bool
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"invalid token"}`))
|
||||
case testFullStartPath:
|
||||
uploadReceived.Store(true)
|
||||
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
assert.False(t, uploadReceived.Load(), "should not trigger backup when API returns 401")
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenUploadSucceeds_BodyIsZstdCompressed(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var receivedBody []byte
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testChainValidPath:
|
||||
writeJSON(w, api.WalChainValidityResponse{
|
||||
IsValid: false,
|
||||
Error: "no_full_backup",
|
||||
})
|
||||
case testFullStartPath:
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
|
||||
mu.Lock()
|
||||
receivedBody = body
|
||||
mu.Unlock()
|
||||
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
case testFullCompletePath:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
originalContent := "test-backup-content-for-compression-check"
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = mockCmdBuilder(t, originalContent, validStderr())
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go fb.Run(ctx)
|
||||
waitForCondition(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return len(receivedBody) > 0
|
||||
}, 5*time.Second)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
body := receivedBody
|
||||
mu.Unlock()
|
||||
|
||||
decoder, err := zstd.NewReader(nil)
|
||||
require.NoError(t, err)
|
||||
defer decoder.Close()
|
||||
|
||||
decompressed, err := decoder.DecodeAll(body, nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalContent, string(decompressed))
|
||||
}
|
||||
|
||||
func Test_RunFullBackup_WhenUploadStalls_FailsWithIdleTimeout(t *testing.T) {
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testFullStartPath:
|
||||
// Server reads body normally — it will block until connection is closed
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
writeJSON(w, map[string]string{"backupId": testBackupID})
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
fb := newTestFullBackuper(server.URL)
|
||||
fb.cmdBuilder = stallingCmdBuilder(t)
|
||||
|
||||
origIdleTimeout := uploadIdleTimeout
|
||||
uploadIdleTimeout = 200 * time.Millisecond
|
||||
defer func() { uploadIdleTimeout = origIdleTimeout }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := fb.executeAndUploadBasebackup(ctx)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "idle timeout", "error should mention idle timeout")
|
||||
}
|
||||
|
||||
func stallingCmdBuilder(t *testing.T) CmdBuilder {
|
||||
t.Helper()
|
||||
|
||||
return func(ctx context.Context) *exec.Cmd {
|
||||
cmd := exec.CommandContext(ctx, os.Args[0],
|
||||
"-test.run=TestHelperProcessStalling",
|
||||
"--",
|
||||
)
|
||||
|
||||
cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS_STALLING=1")
|
||||
|
||||
return cmd
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelperProcessStalling(t *testing.T) {
|
||||
if os.Getenv("GO_TEST_HELPER_PROCESS_STALLING") != "1" {
|
||||
return
|
||||
}
|
||||
|
||||
// Write enough data to flush through the zstd encoder's internal buffer (~128KB blocks).
|
||||
// Without enough data, zstd buffers everything and the pipe never receives bytes.
|
||||
data := make([]byte, 256*1024)
|
||||
for i := range data {
|
||||
data[i] = byte(i)
|
||||
}
|
||||
_, _ = os.Stdout.Write(data)
|
||||
|
||||
// Stall with stdout open — the compress goroutine blocks on its next read.
|
||||
// The parent process will kill us when the context is cancelled.
|
||||
time.Sleep(time.Hour)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
server := httptest.NewServer(handler)
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
func newTestFullBackuper(serverURL string) *FullBackuper {
|
||||
cfg := &config.Config{
|
||||
DatabasusHost: serverURL,
|
||||
DbID: "test-db-id",
|
||||
Token: "test-token",
|
||||
PgHost: "localhost",
|
||||
PgPort: 5432,
|
||||
PgUser: "postgres",
|
||||
PgPassword: "password",
|
||||
PgType: "host",
|
||||
}
|
||||
|
||||
apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger())
|
||||
|
||||
return NewFullBackuper(cfg, apiClient, logger.GetLogger())
|
||||
}
|
||||
|
||||
func mockCmdBuilder(t *testing.T, stdoutContent, stderrContent string) CmdBuilder {
|
||||
t.Helper()
|
||||
|
||||
return func(ctx context.Context) *exec.Cmd {
|
||||
cmd := exec.CommandContext(ctx, os.Args[0],
|
||||
"-test.run=TestHelperProcess",
|
||||
"--",
|
||||
stdoutContent,
|
||||
stderrContent,
|
||||
)
|
||||
|
||||
cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS=1")
|
||||
|
||||
return cmd
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelperProcess(t *testing.T) {
|
||||
if os.Getenv("GO_TEST_HELPER_PROCESS") != "1" {
|
||||
return
|
||||
}
|
||||
|
||||
args := os.Args
|
||||
for i, arg := range args {
|
||||
if arg == "--" {
|
||||
args = args[i+1:]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(args) >= 1 {
|
||||
_, _ = fmt.Fprint(os.Stdout, args[0])
|
||||
}
|
||||
|
||||
if len(args) >= 2 {
|
||||
_, _ = fmt.Fprint(os.Stderr, args[1])
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func validStderr() string {
|
||||
return `pg_basebackup: initiating base backup, waiting for checkpoint to complete
|
||||
pg_basebackup: checkpoint completed
|
||||
pg_basebackup: write-ahead log start point: 0/2000028 on timeline 1
|
||||
pg_basebackup: starting background WAL receiver
|
||||
pg_basebackup: write-ahead log end point: 0/2000100
|
||||
pg_basebackup: waiting for background process to finish streaming ...
|
||||
pg_basebackup: syncing data to disk ...
|
||||
pg_basebackup: base backup completed`
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, v any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if err := json.NewEncoder(w).Encode(v); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func waitForCondition(t *testing.T, condition func() bool, timeout time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(timeout)
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
if condition() {
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Fatalf("condition not met within %v", timeout)
|
||||
}
|
||||
|
||||
func setRetryDelay(d time.Duration) {
|
||||
retryDelayOverride = &d
|
||||
}
|
||||
|
||||
func init() {
|
||||
retryDelayOverride = nil
|
||||
}
|
||||
75
agent/internal/features/full_backup/stderr_parser.go
Normal file
75
agent/internal/features/full_backup/stderr_parser.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package full_backup
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const defaultWalSegmentSize uint32 = 16 * 1024 * 1024 // 16 MB
|
||||
|
||||
var (
|
||||
startLSNRegex = regexp.MustCompile(`write-ahead log start point: ([0-9A-Fa-f]+/[0-9A-Fa-f]+)`)
|
||||
stopLSNRegex = regexp.MustCompile(`write-ahead log end point: ([0-9A-Fa-f]+/[0-9A-Fa-f]+)`)
|
||||
)
|
||||
|
||||
func ParseBasebackupStderr(stderr string) (startSegment, stopSegment string, err error) {
|
||||
startMatch := startLSNRegex.FindStringSubmatch(stderr)
|
||||
if len(startMatch) < 2 {
|
||||
return "", "", fmt.Errorf("failed to parse start WAL location from pg_basebackup stderr")
|
||||
}
|
||||
|
||||
stopMatch := stopLSNRegex.FindStringSubmatch(stderr)
|
||||
if len(stopMatch) < 2 {
|
||||
return "", "", fmt.Errorf("failed to parse stop WAL location from pg_basebackup stderr")
|
||||
}
|
||||
|
||||
startSegment, err = LSNToSegmentName(startMatch[1], 1, defaultWalSegmentSize)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to convert start LSN to segment name: %w", err)
|
||||
}
|
||||
|
||||
stopSegment, err = LSNToSegmentName(stopMatch[1], 1, defaultWalSegmentSize)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to convert stop LSN to segment name: %w", err)
|
||||
}
|
||||
|
||||
return startSegment, stopSegment, nil
|
||||
}
|
||||
|
||||
func LSNToSegmentName(lsn string, timelineID, walSegmentSize uint32) (string, error) {
|
||||
high, low, err := parseLSN(lsn)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
segmentsPerXLogID := uint32(0x100000000 / uint64(walSegmentSize))
|
||||
logID := high
|
||||
segmentOffset := low / walSegmentSize
|
||||
|
||||
if segmentOffset >= segmentsPerXLogID {
|
||||
return "", fmt.Errorf("segment offset %d exceeds segments per XLogId %d", segmentOffset, segmentsPerXLogID)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%08X%08X%08X", timelineID, logID, segmentOffset), nil
|
||||
}
|
||||
|
||||
func parseLSN(lsn string) (high, low uint32, err error) {
|
||||
parts := strings.SplitN(lsn, "/", 2)
|
||||
if len(parts) != 2 {
|
||||
return 0, 0, fmt.Errorf("invalid LSN format: %q (expected X/Y)", lsn)
|
||||
}
|
||||
|
||||
highVal, err := strconv.ParseUint(parts[0], 16, 32)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("invalid LSN high part %q: %w", parts[0], err)
|
||||
}
|
||||
|
||||
lowVal, err := strconv.ParseUint(parts[1], 16, 32)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("invalid LSN low part %q: %w", parts[1], err)
|
||||
}
|
||||
|
||||
return uint32(highVal), uint32(lowVal), nil
|
||||
}
|
||||
157
agent/internal/features/full_backup/stderr_parser_test.go
Normal file
157
agent/internal/features/full_backup/stderr_parser_test.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package full_backup
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_ParseBasebackupStderr_WithPG17FetchOutput_ExtractsCorrectSegments(t *testing.T) {
|
||||
stderr := `pg_basebackup: initiating base backup, waiting for checkpoint to complete
|
||||
pg_basebackup: checkpoint completed
|
||||
pg_basebackup: write-ahead log start point: 0/2000028 on timeline 1
|
||||
pg_basebackup: starting background WAL receiver
|
||||
pg_basebackup: write-ahead log end point: 0/2000100
|
||||
pg_basebackup: waiting for background process to finish streaming ...
|
||||
pg_basebackup: syncing data to disk ...
|
||||
pg_basebackup: renaming backup_manifest.tmp to backup_manifest
|
||||
pg_basebackup: base backup completed`
|
||||
|
||||
startSeg, stopSeg, err := ParseBasebackupStderr(stderr)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "000000010000000000000002", startSeg)
|
||||
assert.Equal(t, "000000010000000000000002", stopSeg)
|
||||
}
|
||||
|
||||
func Test_ParseBasebackupStderr_WithHighLSNValues_ExtractsCorrectSegments(t *testing.T) {
|
||||
stderr := `pg_basebackup: write-ahead log start point: 1/AB000028 on timeline 1
|
||||
pg_basebackup: write-ahead log end point: 1/AC000000`
|
||||
|
||||
startSeg, stopSeg, err := ParseBasebackupStderr(stderr)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "0000000100000001000000AB", startSeg)
|
||||
assert.Equal(t, "0000000100000001000000AC", stopSeg)
|
||||
}
|
||||
|
||||
func Test_ParseBasebackupStderr_WithHighLogID_ExtractsCorrectSegments(t *testing.T) {
|
||||
stderr := `pg_basebackup: write-ahead log start point: A/FF000028 on timeline 1
|
||||
pg_basebackup: write-ahead log end point: B/1000000`
|
||||
|
||||
startSeg, stopSeg, err := ParseBasebackupStderr(stderr)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "000000010000000A000000FF", startSeg)
|
||||
assert.Equal(t, "000000010000000B00000001", stopSeg)
|
||||
}
|
||||
|
||||
func Test_ParseBasebackupStderr_WhenStartLSNMissing_ReturnsError(t *testing.T) {
|
||||
stderr := `pg_basebackup: write-ahead log end point: 0/2000100
|
||||
pg_basebackup: base backup completed`
|
||||
|
||||
_, _, err := ParseBasebackupStderr(stderr)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to parse start WAL location")
|
||||
}
|
||||
|
||||
func Test_ParseBasebackupStderr_WhenStopLSNMissing_ReturnsError(t *testing.T) {
|
||||
stderr := `pg_basebackup: write-ahead log start point: 0/2000028 on timeline 1
|
||||
pg_basebackup: base backup completed`
|
||||
|
||||
_, _, err := ParseBasebackupStderr(stderr)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to parse stop WAL location")
|
||||
}
|
||||
|
||||
func Test_ParseBasebackupStderr_WhenEmptyStderr_ReturnsError(t *testing.T) {
|
||||
_, _, err := ParseBasebackupStderr("")
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to parse start WAL location")
|
||||
}
|
||||
|
||||
func Test_LSNToSegmentName_WithBoundaryValues_ConvertsCorrectly(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
lsn string
|
||||
timeline uint32
|
||||
segSize uint32
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "first segment",
|
||||
lsn: "0/1000000",
|
||||
timeline: 1,
|
||||
segSize: 16 * 1024 * 1024,
|
||||
expected: "000000010000000000000001",
|
||||
},
|
||||
{
|
||||
name: "segment at boundary FF",
|
||||
lsn: "0/FF000000",
|
||||
timeline: 1,
|
||||
segSize: 16 * 1024 * 1024,
|
||||
expected: "0000000100000000000000FF",
|
||||
},
|
||||
{
|
||||
name: "segment in second log file",
|
||||
lsn: "1/0",
|
||||
timeline: 1,
|
||||
segSize: 16 * 1024 * 1024,
|
||||
expected: "000000010000000100000000",
|
||||
},
|
||||
{
|
||||
name: "segment with offset within 16MB",
|
||||
lsn: "0/200ABCD",
|
||||
timeline: 1,
|
||||
segSize: 16 * 1024 * 1024,
|
||||
expected: "000000010000000000000002",
|
||||
},
|
||||
{
|
||||
name: "zero LSN",
|
||||
lsn: "0/0",
|
||||
timeline: 1,
|
||||
segSize: 16 * 1024 * 1024,
|
||||
expected: "000000010000000000000000",
|
||||
},
|
||||
{
|
||||
name: "high timeline ID",
|
||||
lsn: "0/1000000",
|
||||
timeline: 2,
|
||||
segSize: 16 * 1024 * 1024,
|
||||
expected: "000000020000000000000001",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := LSNToSegmentName(tt.lsn, tt.timeline, tt.segSize)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_LSNToSegmentName_WithInvalidLSN_ReturnsError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
lsn string
|
||||
}{
|
||||
{name: "no slash", lsn: "012345"},
|
||||
{name: "empty string", lsn: ""},
|
||||
{name: "invalid hex high", lsn: "GG/0"},
|
||||
{name: "invalid hex low", lsn: "0/ZZ"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := LSNToSegmentName(tt.lsn, 1, 16*1024*1024)
|
||||
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
444
agent/internal/features/restore/restorer.go
Normal file
444
agent/internal/features/restore/restorer.go
Normal file
@@ -0,0 +1,444 @@
|
||||
package restore
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"databasus-agent/internal/features/api"
|
||||
)
|
||||
|
||||
const (
|
||||
walRestoreDir = "databasus-wal-restore"
|
||||
maxRetryAttempts = 3
|
||||
retryBaseDelay = 1 * time.Second
|
||||
recoverySignalFile = "recovery.signal"
|
||||
autoConfFile = "postgresql.auto.conf"
|
||||
dockerContainerPgDataDir = "/var/lib/postgresql/data"
|
||||
)
|
||||
|
||||
var retryDelayOverride *time.Duration
|
||||
|
||||
type Restorer struct {
|
||||
apiClient *api.Client
|
||||
log *slog.Logger
|
||||
targetPgDataDir string
|
||||
backupID string
|
||||
targetTime string
|
||||
pgType string
|
||||
}
|
||||
|
||||
func NewRestorer(
|
||||
apiClient *api.Client,
|
||||
log *slog.Logger,
|
||||
targetPgDataDir string,
|
||||
backupID string,
|
||||
targetTime string,
|
||||
pgType string,
|
||||
) *Restorer {
|
||||
return &Restorer{
|
||||
apiClient,
|
||||
log,
|
||||
targetPgDataDir,
|
||||
backupID,
|
||||
targetTime,
|
||||
pgType,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Restorer) Run(ctx context.Context) error {
|
||||
var parsedTargetTime *time.Time
|
||||
|
||||
if r.targetTime != "" {
|
||||
parsed, err := time.Parse(time.RFC3339, r.targetTime)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid --target-time format (expected RFC3339, e.g. 2026-02-28T14:30:00Z): %w", err)
|
||||
}
|
||||
|
||||
parsedTargetTime = &parsed
|
||||
}
|
||||
|
||||
if err := r.validateTargetPgDataDir(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
plan, err := r.getRestorePlanFromServer(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.logRestorePlan(plan, parsedTargetTime)
|
||||
|
||||
r.log.Info("Downloading and extracting basebackup...")
|
||||
if err := r.downloadAndExtractBasebackup(ctx, plan.FullBackup.BackupID); err != nil {
|
||||
return fmt.Errorf("basebackup download failed: %w", err)
|
||||
}
|
||||
r.log.Info("Basebackup extracted successfully")
|
||||
|
||||
if err := r.downloadAllWalSegments(ctx, plan.WalSegments); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.configurePostgresRecovery(parsedTargetTime); err != nil {
|
||||
return fmt.Errorf("failed to configure recovery: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(r.targetPgDataDir, 0o700); err != nil {
|
||||
return fmt.Errorf("set PGDATA permissions: %w", err)
|
||||
}
|
||||
|
||||
r.printCompletionMessage()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Restorer) validateTargetPgDataDir() error {
|
||||
info, err := os.Stat(r.targetPgDataDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("target pgdata directory does not exist: %s", r.targetPgDataDir)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot access target pgdata directory: %w", err)
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
return fmt.Errorf("target pgdata path is not a directory: %s", r.targetPgDataDir)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(r.targetPgDataDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read target pgdata directory: %w", err)
|
||||
}
|
||||
|
||||
if len(entries) > 0 {
|
||||
return fmt.Errorf("target pgdata directory is not empty: %s", r.targetPgDataDir)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Restorer) getRestorePlanFromServer(ctx context.Context) (*api.GetRestorePlanResponse, error) {
|
||||
plan, planErr, err := r.apiClient.GetRestorePlan(ctx, r.backupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch restore plan: %w", err)
|
||||
}
|
||||
|
||||
if planErr != nil {
|
||||
if planErr.LastContiguousSegment != "" {
|
||||
return nil, fmt.Errorf("restore plan error: %s (last contiguous segment: %s)",
|
||||
planErr.Message, planErr.LastContiguousSegment)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("restore plan error: %s", planErr.Message)
|
||||
}
|
||||
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
func (r *Restorer) logRestorePlan(plan *api.GetRestorePlanResponse, parsedTargetTime *time.Time) {
|
||||
recoveryTarget := "full recovery (all available WAL)"
|
||||
if parsedTargetTime != nil {
|
||||
recoveryTarget = parsedTargetTime.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
r.log.Info("Restore plan",
|
||||
"fullBackupID", plan.FullBackup.BackupID,
|
||||
"fullBackupCreatedAt", plan.FullBackup.CreatedAt.Format(time.RFC3339),
|
||||
"pgVersion", plan.FullBackup.PgVersion,
|
||||
"walSegmentCount", len(plan.WalSegments),
|
||||
"totalDownloadSize", formatSizeBytes(plan.TotalSizeBytes),
|
||||
"latestAvailableSegment", plan.LatestAvailableSegment,
|
||||
"recoveryTarget", recoveryTarget,
|
||||
)
|
||||
}
|
||||
|
||||
func (r *Restorer) downloadAndExtractBasebackup(ctx context.Context, backupID string) error {
|
||||
body, err := r.apiClient.DownloadBackupFile(ctx, backupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = body.Close() }()
|
||||
|
||||
zstdReader, err := zstd.NewReader(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create zstd decompressor: %w", err)
|
||||
}
|
||||
defer zstdReader.Close()
|
||||
|
||||
tarReader := tar.NewReader(zstdReader)
|
||||
|
||||
return r.extractTarArchive(tarReader)
|
||||
}
|
||||
|
||||
func (r *Restorer) extractTarArchive(tarReader *tar.Reader) error {
|
||||
for {
|
||||
header, err := tarReader.Next()
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("read tar entry: %w", err)
|
||||
}
|
||||
|
||||
targetPath := filepath.Join(r.targetPgDataDir, header.Name)
|
||||
|
||||
relativePath, err := filepath.Rel(r.targetPgDataDir, targetPath)
|
||||
if err != nil || strings.HasPrefix(relativePath, "..") {
|
||||
return fmt.Errorf("tar entry attempts path traversal: %s", header.Name)
|
||||
}
|
||||
|
||||
switch header.Typeflag {
|
||||
case tar.TypeDir:
|
||||
if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil {
|
||||
return fmt.Errorf("create directory %s: %w", header.Name, err)
|
||||
}
|
||||
|
||||
case tar.TypeReg:
|
||||
parentDir := filepath.Dir(targetPath)
|
||||
if err := os.MkdirAll(parentDir, 0o755); err != nil {
|
||||
return fmt.Errorf("create parent directory for %s: %w", header.Name, err)
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create file %s: %w", header.Name, err)
|
||||
}
|
||||
|
||||
if _, err := io.Copy(file, tarReader); err != nil {
|
||||
_ = file.Close()
|
||||
return fmt.Errorf("write file %s: %w", header.Name, err)
|
||||
}
|
||||
|
||||
_ = file.Close()
|
||||
|
||||
case tar.TypeSymlink:
|
||||
if err := os.Symlink(header.Linkname, targetPath); err != nil {
|
||||
return fmt.Errorf("create symlink %s: %w", header.Name, err)
|
||||
}
|
||||
|
||||
case tar.TypeLink:
|
||||
linkTarget := filepath.Join(r.targetPgDataDir, header.Linkname)
|
||||
if err := os.Link(linkTarget, targetPath); err != nil {
|
||||
return fmt.Errorf("create hard link %s: %w", header.Name, err)
|
||||
}
|
||||
|
||||
default:
|
||||
r.log.Warn("Skipping unsupported tar entry type",
|
||||
"name", header.Name,
|
||||
"type", header.Typeflag,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Restorer) downloadAllWalSegments(ctx context.Context, segments []api.RestorePlanWalSegment) error {
|
||||
walRestorePath := filepath.Join(r.targetPgDataDir, walRestoreDir)
|
||||
if err := os.MkdirAll(walRestorePath, 0o755); err != nil {
|
||||
return fmt.Errorf("create WAL restore directory: %w", err)
|
||||
}
|
||||
|
||||
for segmentIndex, segment := range segments {
|
||||
if err := r.downloadWalSegmentWithRetry(ctx, segment, segmentIndex, len(segments)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Restorer) downloadWalSegmentWithRetry(
|
||||
ctx context.Context,
|
||||
segment api.RestorePlanWalSegment,
|
||||
segmentIndex int,
|
||||
segmentsTotal int,
|
||||
) error {
|
||||
r.log.Info("Downloading WAL segment",
|
||||
"segment", segment.SegmentName,
|
||||
"progress", fmt.Sprintf("%d/%d", segmentIndex+1, segmentsTotal),
|
||||
)
|
||||
|
||||
var lastErr error
|
||||
|
||||
for attempt := range maxRetryAttempts {
|
||||
if err := r.downloadWalSegment(ctx, segment); err != nil {
|
||||
lastErr = err
|
||||
|
||||
delay := r.getRetryDelay(attempt)
|
||||
r.log.Warn("WAL segment download failed, retrying",
|
||||
"segment", segment.SegmentName,
|
||||
"attempt", attempt+1,
|
||||
"maxAttempts", maxRetryAttempts,
|
||||
"retryDelay", delay,
|
||||
"error", err,
|
||||
)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to download WAL segment %s after %d attempts: %w",
|
||||
segment.SegmentName, maxRetryAttempts, lastErr)
|
||||
}
|
||||
|
||||
func (r *Restorer) downloadWalSegment(ctx context.Context, segment api.RestorePlanWalSegment) error {
|
||||
body, err := r.apiClient.DownloadBackupFile(ctx, segment.BackupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = body.Close() }()
|
||||
|
||||
zstdReader, err := zstd.NewReader(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create zstd decompressor: %w", err)
|
||||
}
|
||||
defer zstdReader.Close()
|
||||
|
||||
segmentPath := filepath.Join(r.targetPgDataDir, walRestoreDir, segment.SegmentName)
|
||||
|
||||
file, err := os.Create(segmentPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create WAL segment file: %w", err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
if _, err := io.Copy(file, zstdReader); err != nil {
|
||||
return fmt.Errorf("write WAL segment: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Restorer) configurePostgresRecovery(parsedTargetTime *time.Time) error {
|
||||
recoverySignalPath := filepath.Join(r.targetPgDataDir, recoverySignalFile)
|
||||
if err := os.WriteFile(recoverySignalPath, []byte{}, 0o644); err != nil {
|
||||
return fmt.Errorf("create recovery.signal: %w", err)
|
||||
}
|
||||
|
||||
walRestoreAbsPath, err := r.resolveWalRestorePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
autoConfPath := filepath.Join(r.targetPgDataDir, autoConfFile)
|
||||
|
||||
autoConfFile, err := os.OpenFile(autoConfPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open postgresql.auto.conf: %w", err)
|
||||
}
|
||||
defer func() { _ = autoConfFile.Close() }()
|
||||
|
||||
var configLines strings.Builder
|
||||
configLines.WriteString("\n# Added by databasus-agent restore\n")
|
||||
fmt.Fprintf(&configLines, "restore_command = 'cp %s/%%f %%p'\n", walRestoreAbsPath)
|
||||
fmt.Fprintf(&configLines, "recovery_end_command = 'rm -rf %s'\n", walRestoreAbsPath)
|
||||
configLines.WriteString("recovery_target_action = 'promote'\n")
|
||||
|
||||
if parsedTargetTime != nil {
|
||||
fmt.Fprintf(&configLines, "recovery_target_time = '%s'\n", parsedTargetTime.Format(time.RFC3339))
|
||||
}
|
||||
|
||||
if _, err := autoConfFile.WriteString(configLines.String()); err != nil {
|
||||
return fmt.Errorf("write to postgresql.auto.conf: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Restorer) printCompletionMessage() {
|
||||
absPgDataDir, _ := filepath.Abs(r.targetPgDataDir)
|
||||
isDocker := r.pgType == "docker"
|
||||
|
||||
fmt.Printf("\nRestore complete. PGDATA directory is ready at %s.\n", absPgDataDir)
|
||||
|
||||
fmt.Print(`
|
||||
What happens when you start PostgreSQL:
|
||||
1. PostgreSQL detects recovery.signal and enters recovery mode
|
||||
2. It replays WAL from the basebackup's consistency point
|
||||
3. It executes restore_command to fetch WAL segments from databasus-wal-restore/
|
||||
4. WAL replay continues until target_time (if PITR) or end of available WAL
|
||||
5. recovery_end_command automatically removes databasus-wal-restore/
|
||||
6. PostgreSQL promotes to primary and removes recovery.signal
|
||||
7. Normal operations resume
|
||||
`)
|
||||
|
||||
if isDocker {
|
||||
fmt.Printf(`
|
||||
Start PostgreSQL by launching a container with the restored data mounted:
|
||||
docker run -d -v %s:%s postgres:<VERSION>
|
||||
|
||||
Or if you have an existing container:
|
||||
docker start <CONTAINER_NAME>
|
||||
|
||||
Ensure %s is mounted as the container's pgdata volume at %s.
|
||||
`, absPgDataDir, dockerContainerPgDataDir, absPgDataDir, dockerContainerPgDataDir)
|
||||
} else {
|
||||
fmt.Printf(`
|
||||
Start PostgreSQL:
|
||||
pg_ctl -D %s start
|
||||
|
||||
Note: If you move the PGDATA directory before starting PostgreSQL,
|
||||
update restore_command and recovery_end_command paths in
|
||||
postgresql.auto.conf accordingly.
|
||||
`, absPgDataDir)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Restorer) resolveWalRestorePath() (string, error) {
|
||||
if r.pgType == "docker" {
|
||||
return dockerContainerPgDataDir + "/" + walRestoreDir, nil
|
||||
}
|
||||
|
||||
absPgDataDir, err := filepath.Abs(r.targetPgDataDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resolve absolute path: %w", err)
|
||||
}
|
||||
|
||||
absPgDataDir = filepath.ToSlash(absPgDataDir)
|
||||
|
||||
return absPgDataDir + "/" + walRestoreDir, nil
|
||||
}
|
||||
|
||||
func (r *Restorer) getRetryDelay(attempt int) time.Duration {
|
||||
if retryDelayOverride != nil {
|
||||
return *retryDelayOverride
|
||||
}
|
||||
|
||||
return retryBaseDelay * time.Duration(1<<attempt)
|
||||
}
|
||||
|
||||
func formatSizeBytes(sizeBytes int64) string {
|
||||
const (
|
||||
kilobyte = 1024
|
||||
megabyte = 1024 * kilobyte
|
||||
gigabyte = 1024 * megabyte
|
||||
)
|
||||
|
||||
switch {
|
||||
case sizeBytes >= gigabyte:
|
||||
return fmt.Sprintf("%.2f GB", float64(sizeBytes)/float64(gigabyte))
|
||||
case sizeBytes >= megabyte:
|
||||
return fmt.Sprintf("%.2f MB", float64(sizeBytes)/float64(megabyte))
|
||||
case sizeBytes >= kilobyte:
|
||||
return fmt.Sprintf("%.2f KB", float64(sizeBytes)/float64(kilobyte))
|
||||
default:
|
||||
return fmt.Sprintf("%d B", sizeBytes)
|
||||
}
|
||||
}
|
||||
711
agent/internal/features/restore/restorer_test.go
Normal file
711
agent/internal/features/restore/restorer_test.go
Normal file
@@ -0,0 +1,711 @@
|
||||
package restore
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-agent/internal/features/api"
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
testRestorePlanPath = "/api/v1/backups/postgres/wal/restore/plan"
|
||||
testRestoreDownloadPath = "/api/v1/backups/postgres/wal/restore/download"
|
||||
|
||||
testFullBackupID = "full-backup-id-1234"
|
||||
testWalSegment1 = "000000010000000100000001"
|
||||
testWalSegment2 = "000000010000000100000002"
|
||||
)
|
||||
|
||||
func Test_RunRestore_WhenBasebackupAndWalSegmentsAvailable_FilesExtractedAndRecoveryConfigured(t *testing.T) {
|
||||
tarFiles := map[string][]byte{
|
||||
"PG_VERSION": []byte("16"),
|
||||
"base/1/somefile": []byte("table-data"),
|
||||
}
|
||||
zstdTarData := createZstdTar(t, tarFiles)
|
||||
walData1 := createZstdData(t, []byte("wal-segment-1-data"))
|
||||
walData2 := createZstdData(t, []byte("wal-segment-2-data"))
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testRestorePlanPath:
|
||||
writeJSON(w, api.GetRestorePlanResponse{
|
||||
FullBackup: api.RestorePlanFullBackup{
|
||||
BackupID: testFullBackupID,
|
||||
FullBackupWalStartSegment: testWalSegment1,
|
||||
FullBackupWalStopSegment: testWalSegment1,
|
||||
PgVersion: "16",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SizeBytes: 1024,
|
||||
},
|
||||
WalSegments: []api.RestorePlanWalSegment{
|
||||
{BackupID: "wal-1", SegmentName: testWalSegment1, SizeBytes: 512},
|
||||
{BackupID: "wal-2", SegmentName: testWalSegment2, SizeBytes: 512},
|
||||
},
|
||||
TotalSizeBytes: 2048,
|
||||
LatestAvailableSegment: testWalSegment2,
|
||||
})
|
||||
|
||||
case testRestoreDownloadPath:
|
||||
backupID := r.URL.Query().Get("backupId")
|
||||
switch backupID {
|
||||
case testFullBackupID:
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(zstdTarData)
|
||||
case "wal-1":
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(walData1)
|
||||
case "wal-2":
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(walData2)
|
||||
default:
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
pgVersionContent, err := os.ReadFile(filepath.Join(targetDir, "PG_VERSION"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "16", string(pgVersionContent))
|
||||
|
||||
someFileContent, err := os.ReadFile(filepath.Join(targetDir, "base", "1", "somefile"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "table-data", string(someFileContent))
|
||||
|
||||
walSegment1Content, err := os.ReadFile(filepath.Join(targetDir, walRestoreDir, testWalSegment1))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "wal-segment-1-data", string(walSegment1Content))
|
||||
|
||||
walSegment2Content, err := os.ReadFile(filepath.Join(targetDir, walRestoreDir, testWalSegment2))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "wal-segment-2-data", string(walSegment2Content))
|
||||
|
||||
recoverySignalPath := filepath.Join(targetDir, "recovery.signal")
|
||||
recoverySignalInfo, err := os.Stat(recoverySignalPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), recoverySignalInfo.Size())
|
||||
|
||||
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
|
||||
require.NoError(t, err)
|
||||
autoConfStr := string(autoConfContent)
|
||||
|
||||
assert.Contains(t, autoConfStr, "restore_command")
|
||||
assert.Contains(t, autoConfStr, walRestoreDir)
|
||||
assert.Contains(t, autoConfStr, "recovery_target_action = 'promote'")
|
||||
assert.Contains(t, autoConfStr, "recovery_end_command")
|
||||
assert.NotContains(t, autoConfStr, "recovery_target_time")
|
||||
}
|
||||
|
||||
func Test_RunRestore_WhenTargetTimeProvided_RecoveryTargetTimeWrittenToConfig(t *testing.T) {
|
||||
tarFiles := map[string][]byte{"PG_VERSION": []byte("16")}
|
||||
zstdTarData := createZstdTar(t, tarFiles)
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testRestorePlanPath:
|
||||
writeJSON(w, api.GetRestorePlanResponse{
|
||||
FullBackup: api.RestorePlanFullBackup{
|
||||
BackupID: testFullBackupID,
|
||||
PgVersion: "16",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SizeBytes: 1024,
|
||||
},
|
||||
WalSegments: []api.RestorePlanWalSegment{},
|
||||
TotalSizeBytes: 1024,
|
||||
LatestAvailableSegment: "",
|
||||
})
|
||||
|
||||
case testRestoreDownloadPath:
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(zstdTarData)
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "2026-02-28T14:30:00Z", "")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, string(autoConfContent), "recovery_target_time = '2026-02-28T14:30:00Z'")
|
||||
}
|
||||
|
||||
func Test_RunRestore_WhenPgDataDirNotEmpty_ReturnsError(t *testing.T) {
|
||||
targetDir := createTestTargetDir(t)
|
||||
|
||||
err := os.WriteFile(filepath.Join(targetDir, "existing-file"), []byte("data"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
restorer := newTestRestorer("http://localhost:0", targetDir, "", "", "")
|
||||
|
||||
err = restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not empty")
|
||||
}
|
||||
|
||||
func Test_RunRestore_WhenPgDataDirDoesNotExist_ReturnsError(t *testing.T) {
|
||||
nonExistentDir := filepath.Join(os.TempDir(), "databasus-test-nonexistent-dir-12345")
|
||||
|
||||
restorer := newTestRestorer("http://localhost:0", nonExistentDir, "", "", "")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "does not exist")
|
||||
}
|
||||
|
||||
func Test_RunRestore_WhenNoBackupsAvailable_ReturnsError(t *testing.T) {
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_ = json.NewEncoder(w).Encode(api.GetRestorePlanErrorResponse{
|
||||
Error: "no_backups",
|
||||
Message: "No full backups available",
|
||||
})
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "No full backups available")
|
||||
}
|
||||
|
||||
func Test_RunRestore_WhenWalChainBroken_ReturnsError(t *testing.T) {
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_ = json.NewEncoder(w).Encode(api.GetRestorePlanErrorResponse{
|
||||
Error: "wal_chain_broken",
|
||||
Message: "WAL chain broken",
|
||||
LastContiguousSegment: testWalSegment1,
|
||||
})
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "WAL chain broken")
|
||||
assert.Contains(t, err.Error(), testWalSegment1)
|
||||
}
|
||||
|
||||
func Test_DownloadWalSegment_WhenFirstAttemptFails_RetriesAndSucceeds(t *testing.T) {
|
||||
tarFiles := map[string][]byte{"PG_VERSION": []byte("16")}
|
||||
zstdTarData := createZstdTar(t, tarFiles)
|
||||
walData := createZstdData(t, []byte("wal-segment-data"))
|
||||
|
||||
var mu sync.Mutex
|
||||
var walDownloadAttempts int
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testRestorePlanPath:
|
||||
writeJSON(w, api.GetRestorePlanResponse{
|
||||
FullBackup: api.RestorePlanFullBackup{
|
||||
BackupID: testFullBackupID,
|
||||
PgVersion: "16",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SizeBytes: 1024,
|
||||
},
|
||||
WalSegments: []api.RestorePlanWalSegment{
|
||||
{BackupID: "wal-1", SegmentName: testWalSegment1, SizeBytes: 512},
|
||||
},
|
||||
TotalSizeBytes: 1536,
|
||||
LatestAvailableSegment: testWalSegment1,
|
||||
})
|
||||
|
||||
case testRestoreDownloadPath:
|
||||
backupID := r.URL.Query().Get("backupId")
|
||||
if backupID == testFullBackupID {
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(zstdTarData)
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
walDownloadAttempts++
|
||||
attempt := walDownloadAttempts
|
||||
mu.Unlock()
|
||||
|
||||
if attempt == 1 {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(`{"error":"storage unavailable"}`))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(walData)
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
origDelay := retryDelayOverride
|
||||
testDelay := 10 * time.Millisecond
|
||||
retryDelayOverride = &testDelay
|
||||
defer func() { retryDelayOverride = origDelay }()
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
mu.Lock()
|
||||
attempts := walDownloadAttempts
|
||||
mu.Unlock()
|
||||
|
||||
assert.Equal(t, 2, attempts)
|
||||
|
||||
walContent, err := os.ReadFile(filepath.Join(targetDir, walRestoreDir, testWalSegment1))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "wal-segment-data", string(walContent))
|
||||
}
|
||||
|
||||
func Test_DownloadWalSegment_WhenAllAttemptsFail_ReturnsErrorWithSegmentName(t *testing.T) {
|
||||
tarFiles := map[string][]byte{"PG_VERSION": []byte("16")}
|
||||
zstdTarData := createZstdTar(t, tarFiles)
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testRestorePlanPath:
|
||||
writeJSON(w, api.GetRestorePlanResponse{
|
||||
FullBackup: api.RestorePlanFullBackup{
|
||||
BackupID: testFullBackupID,
|
||||
PgVersion: "16",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SizeBytes: 1024,
|
||||
},
|
||||
WalSegments: []api.RestorePlanWalSegment{
|
||||
{BackupID: "wal-1", SegmentName: testWalSegment1, SizeBytes: 512},
|
||||
},
|
||||
TotalSizeBytes: 1536,
|
||||
LatestAvailableSegment: testWalSegment1,
|
||||
})
|
||||
|
||||
case testRestoreDownloadPath:
|
||||
backupID := r.URL.Query().Get("backupId")
|
||||
if backupID == testFullBackupID {
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(zstdTarData)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(`{"error":"storage unavailable"}`))
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
origDelay := retryDelayOverride
|
||||
testDelay := 10 * time.Millisecond
|
||||
retryDelayOverride = &testDelay
|
||||
defer func() { retryDelayOverride = origDelay }()
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), testWalSegment1)
|
||||
assert.Contains(t, err.Error(), "3 attempts")
|
||||
}
|
||||
|
||||
func Test_RunRestore_WhenInvalidTargetTimeFormat_ReturnsError(t *testing.T) {
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer("http://localhost:0", targetDir, "", "not-a-valid-time", "")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid --target-time format")
|
||||
}
|
||||
|
||||
func Test_RunRestore_WhenBasebackupDownloadFails_ReturnsError(t *testing.T) {
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testRestorePlanPath:
|
||||
writeJSON(w, api.GetRestorePlanResponse{
|
||||
FullBackup: api.RestorePlanFullBackup{
|
||||
BackupID: testFullBackupID,
|
||||
PgVersion: "16",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SizeBytes: 1024,
|
||||
},
|
||||
WalSegments: []api.RestorePlanWalSegment{},
|
||||
TotalSizeBytes: 1024,
|
||||
LatestAvailableSegment: "",
|
||||
})
|
||||
|
||||
case testRestoreDownloadPath:
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(`{"error":"storage error"}`))
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "basebackup download failed")
|
||||
}
|
||||
|
||||
func Test_RunRestore_WhenNoWalSegmentsInPlan_BasebackupRestoredSuccessfully(t *testing.T) {
|
||||
tarFiles := map[string][]byte{
|
||||
"PG_VERSION": []byte("16"),
|
||||
"global/pg_control": []byte("control-data"),
|
||||
}
|
||||
zstdTarData := createZstdTar(t, tarFiles)
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testRestorePlanPath:
|
||||
writeJSON(w, api.GetRestorePlanResponse{
|
||||
FullBackup: api.RestorePlanFullBackup{
|
||||
BackupID: testFullBackupID,
|
||||
PgVersion: "16",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SizeBytes: 1024,
|
||||
},
|
||||
WalSegments: []api.RestorePlanWalSegment{},
|
||||
TotalSizeBytes: 1024,
|
||||
LatestAvailableSegment: "",
|
||||
})
|
||||
|
||||
case testRestoreDownloadPath:
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(zstdTarData)
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
pgVersionContent, err := os.ReadFile(filepath.Join(targetDir, "PG_VERSION"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "16", string(pgVersionContent))
|
||||
|
||||
walRestoreDirInfo, err := os.Stat(filepath.Join(targetDir, walRestoreDir))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, walRestoreDirInfo.IsDir())
|
||||
|
||||
_, err = os.Stat(filepath.Join(targetDir, "recovery.signal"))
|
||||
require.NoError(t, err)
|
||||
|
||||
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(autoConfContent), "restore_command")
|
||||
}
|
||||
|
||||
func Test_RunRestore_WhenMakingApiCalls_AuthTokenIncludedInRequests(t *testing.T) {
|
||||
tarFiles := map[string][]byte{"PG_VERSION": []byte("16")}
|
||||
zstdTarData := createZstdTar(t, tarFiles)
|
||||
|
||||
var receivedAuthHeaders atomic.Int32
|
||||
var mu sync.Mutex
|
||||
var authHeaderValues []string
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader != "" {
|
||||
receivedAuthHeaders.Add(1)
|
||||
|
||||
mu.Lock()
|
||||
authHeaderValues = append(authHeaderValues, authHeader)
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
switch r.URL.Path {
|
||||
case testRestorePlanPath:
|
||||
writeJSON(w, api.GetRestorePlanResponse{
|
||||
FullBackup: api.RestorePlanFullBackup{
|
||||
BackupID: testFullBackupID,
|
||||
PgVersion: "16",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SizeBytes: 1024,
|
||||
},
|
||||
WalSegments: []api.RestorePlanWalSegment{},
|
||||
TotalSizeBytes: 1024,
|
||||
LatestAvailableSegment: "",
|
||||
})
|
||||
|
||||
case testRestoreDownloadPath:
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(zstdTarData)
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.GreaterOrEqual(t, int(receivedAuthHeaders.Load()), 2)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
for _, headerValue := range authHeaderValues {
|
||||
assert.Equal(t, "test-token", headerValue)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ConfigurePostgresRecovery_WhenPgTypeHost_UsesHostAbsolutePath(t *testing.T) {
|
||||
tarFiles := map[string][]byte{"PG_VERSION": []byte("16")}
|
||||
zstdTarData := createZstdTar(t, tarFiles)
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testRestorePlanPath:
|
||||
writeJSON(w, api.GetRestorePlanResponse{
|
||||
FullBackup: api.RestorePlanFullBackup{
|
||||
BackupID: testFullBackupID,
|
||||
PgVersion: "16",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SizeBytes: 1024,
|
||||
},
|
||||
WalSegments: []api.RestorePlanWalSegment{},
|
||||
TotalSizeBytes: 1024,
|
||||
LatestAvailableSegment: "",
|
||||
})
|
||||
|
||||
case testRestoreDownloadPath:
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(zstdTarData)
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "host")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
|
||||
require.NoError(t, err)
|
||||
autoConfStr := string(autoConfContent)
|
||||
|
||||
absTargetDir, _ := filepath.Abs(targetDir)
|
||||
absTargetDir = filepath.ToSlash(absTargetDir)
|
||||
expectedWalPath := absTargetDir + "/" + walRestoreDir
|
||||
|
||||
assert.Contains(t, autoConfStr, fmt.Sprintf("restore_command = 'cp %s/%%f %%p'", expectedWalPath))
|
||||
assert.Contains(t, autoConfStr, fmt.Sprintf("recovery_end_command = 'rm -rf %s'", expectedWalPath))
|
||||
assert.NotContains(t, autoConfStr, "/var/lib/postgresql/data")
|
||||
}
|
||||
|
||||
func Test_ConfigurePostgresRecovery_WhenPgTypeDocker_UsesContainerPath(t *testing.T) {
|
||||
tarFiles := map[string][]byte{"PG_VERSION": []byte("16")}
|
||||
zstdTarData := createZstdTar(t, tarFiles)
|
||||
|
||||
server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case testRestorePlanPath:
|
||||
writeJSON(w, api.GetRestorePlanResponse{
|
||||
FullBackup: api.RestorePlanFullBackup{
|
||||
BackupID: testFullBackupID,
|
||||
PgVersion: "16",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SizeBytes: 1024,
|
||||
},
|
||||
WalSegments: []api.RestorePlanWalSegment{},
|
||||
TotalSizeBytes: 1024,
|
||||
LatestAvailableSegment: "",
|
||||
})
|
||||
|
||||
case testRestoreDownloadPath:
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(zstdTarData)
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
targetDir := createTestTargetDir(t)
|
||||
restorer := newTestRestorer(server.URL, targetDir, "", "", "docker")
|
||||
|
||||
err := restorer.Run(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
autoConfContent, err := os.ReadFile(filepath.Join(targetDir, "postgresql.auto.conf"))
|
||||
require.NoError(t, err)
|
||||
autoConfStr := string(autoConfContent)
|
||||
|
||||
expectedWalPath := "/var/lib/postgresql/data/" + walRestoreDir
|
||||
|
||||
assert.Contains(t, autoConfStr, fmt.Sprintf("restore_command = 'cp %s/%%f %%p'", expectedWalPath))
|
||||
assert.Contains(t, autoConfStr, fmt.Sprintf("recovery_end_command = 'rm -rf %s'", expectedWalPath))
|
||||
|
||||
absTargetDir, _ := filepath.Abs(targetDir)
|
||||
absTargetDir = filepath.ToSlash(absTargetDir)
|
||||
assert.NotContains(t, autoConfStr, absTargetDir)
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
server := httptest.NewServer(handler)
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
func createTestTargetDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
baseDir := filepath.Join(".", ".test-tmp")
|
||||
if err := os.MkdirAll(baseDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create base test dir: %v", err)
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp(baseDir, t.Name()+"-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test target dir: %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(dir)
|
||||
})
|
||||
|
||||
return dir
|
||||
}
|
||||
|
||||
func createZstdTar(t *testing.T, files map[string][]byte) []byte {
|
||||
t.Helper()
|
||||
|
||||
var tarBuffer bytes.Buffer
|
||||
tarWriter := tar.NewWriter(&tarBuffer)
|
||||
|
||||
createdDirs := make(map[string]bool)
|
||||
|
||||
for name, content := range files {
|
||||
dir := filepath.Dir(name)
|
||||
if dir != "." && !createdDirs[dir] {
|
||||
parts := strings.Split(filepath.ToSlash(dir), "/")
|
||||
for partIndex := range parts {
|
||||
partialDir := strings.Join(parts[:partIndex+1], "/")
|
||||
if !createdDirs[partialDir] {
|
||||
err := tarWriter.WriteHeader(&tar.Header{
|
||||
Name: partialDir + "/",
|
||||
Typeflag: tar.TypeDir,
|
||||
Mode: 0o755,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
createdDirs[partialDir] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err := tarWriter.WriteHeader(&tar.Header{
|
||||
Name: name,
|
||||
Size: int64(len(content)),
|
||||
Mode: 0o644,
|
||||
Typeflag: tar.TypeReg,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tarWriter.Write(content)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.NoError(t, tarWriter.Close())
|
||||
|
||||
var zstdBuffer bytes.Buffer
|
||||
|
||||
encoder, err := zstd.NewWriter(&zstdBuffer,
|
||||
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(5)),
|
||||
zstd.WithEncoderCRC(true),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = encoder.Write(tarBuffer.Bytes())
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, encoder.Close())
|
||||
|
||||
return zstdBuffer.Bytes()
|
||||
}
|
||||
|
||||
func createZstdData(t *testing.T, data []byte) []byte {
|
||||
t.Helper()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
|
||||
encoder, err := zstd.NewWriter(&buffer,
|
||||
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(5)),
|
||||
zstd.WithEncoderCRC(true),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = encoder.Write(data)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, encoder.Close())
|
||||
|
||||
return buffer.Bytes()
|
||||
}
|
||||
|
||||
func newTestRestorer(serverURL, targetPgDataDir, backupID, targetTime, pgType string) *Restorer {
|
||||
apiClient := api.NewClient(serverURL, "test-token", logger.GetLogger())
|
||||
|
||||
return NewRestorer(apiClient, logger.GetLogger(), targetPgDataDir, backupID, targetTime, pgType)
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, value any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if err := json.NewEncoder(w).Encode(value); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
121
agent/internal/features/start/daemon.go
Normal file
121
agent/internal/features/start/daemon.go
Normal file
@@ -0,0 +1,121 @@
|
||||
//go:build !windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
logFileName = "databasus.log"
|
||||
stopTimeout = 30 * time.Second
|
||||
stopPollInterval = 500 * time.Millisecond
|
||||
daemonStartupDelay = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
func Stop(log *slog.Logger) error {
|
||||
pid, err := ReadLockFilePID()
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return errors.New("agent is not running (no lock file found)")
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to read lock file: %w", err)
|
||||
}
|
||||
|
||||
if !isProcessAlive(pid) {
|
||||
_ = os.Remove(lockFileName)
|
||||
return fmt.Errorf("agent is not running (stale lock file removed, PID %d)", pid)
|
||||
}
|
||||
|
||||
log.Info("Sending SIGTERM to agent", "pid", pid)
|
||||
|
||||
if err := syscall.Kill(pid, syscall.SIGTERM); err != nil {
|
||||
return fmt.Errorf("failed to send SIGTERM to PID %d: %w", pid, err)
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(stopTimeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if !isProcessAlive(pid) {
|
||||
log.Info("Agent stopped", "pid", pid)
|
||||
return nil
|
||||
}
|
||||
|
||||
time.Sleep(stopPollInterval)
|
||||
}
|
||||
|
||||
return fmt.Errorf("agent (PID %d) did not stop within %s — process may be stuck", pid, stopTimeout)
|
||||
}
|
||||
|
||||
func Status(log *slog.Logger) error {
|
||||
pid, err := ReadLockFilePID()
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
fmt.Println("Agent is not running")
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to read lock file: %w", err)
|
||||
}
|
||||
|
||||
if isProcessAlive(pid) {
|
||||
fmt.Printf("Agent is running (PID %d)\n", pid)
|
||||
} else {
|
||||
fmt.Println("Agent is not running (stale lock file)")
|
||||
_ = os.Remove(lockFileName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func spawnDaemon(log *slog.Logger) (int, error) {
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to resolve executable path: %w", err)
|
||||
}
|
||||
|
||||
args := []string{"_run"}
|
||||
|
||||
logFile, err := os.OpenFile(logFileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to open log file %s: %w", logFileName, err)
|
||||
}
|
||||
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
_ = logFile.Close()
|
||||
return 0, fmt.Errorf("failed to get working directory: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(context.Background(), execPath, args...)
|
||||
cmd.Dir = cwd
|
||||
cmd.Stderr = logFile
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
_ = logFile.Close()
|
||||
return 0, fmt.Errorf("failed to start daemon process: %w", err)
|
||||
}
|
||||
|
||||
pid := cmd.Process.Pid
|
||||
|
||||
// Detach — we don't wait for the child
|
||||
_ = logFile.Close()
|
||||
|
||||
time.Sleep(daemonStartupDelay)
|
||||
|
||||
if !isProcessAlive(pid) {
|
||||
return 0, fmt.Errorf("daemon process (PID %d) exited immediately — check %s for details", pid, logFileName)
|
||||
}
|
||||
|
||||
log.Info("Daemon spawned", "pid", pid, "log", logFileName)
|
||||
|
||||
return pid, nil
|
||||
}
|
||||
20
agent/internal/features/start/daemon_windows.go
Normal file
20
agent/internal/features/start/daemon_windows.go
Normal file
@@ -0,0 +1,20 @@
|
||||
//go:build windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
func Stop(log *slog.Logger) error {
|
||||
return errors.New("stop is not supported on Windows — use Ctrl+C in the terminal where the agent is running")
|
||||
}
|
||||
|
||||
func Status(log *slog.Logger) error {
|
||||
return errors.New("status is not supported on Windows — check the terminal where the agent is running")
|
||||
}
|
||||
|
||||
func spawnDaemon(_ *slog.Logger) (int, error) {
|
||||
return 0, errors.New("daemon mode is not supported on Windows")
|
||||
}
|
||||
132
agent/internal/features/start/lock.go
Normal file
132
agent/internal/features/start/lock.go
Normal file
@@ -0,0 +1,132 @@
|
||||
//go:build !windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
const lockFileName = "databasus.lock"
|
||||
|
||||
func AcquireLock(log *slog.Logger) (*os.File, error) {
|
||||
f, err := os.OpenFile(lockFileName, os.O_CREATE|os.O_RDWR, 0o644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open lock file: %w", err)
|
||||
}
|
||||
|
||||
err = syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
|
||||
if err == nil {
|
||||
if err := writePID(f); err != nil {
|
||||
_ = f.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Info("Process lock acquired", "pid", os.Getpid(), "lockFile", lockFileName)
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
if !errors.Is(err, syscall.EWOULDBLOCK) {
|
||||
_ = f.Close()
|
||||
return nil, fmt.Errorf("failed to acquire lock: %w", err)
|
||||
}
|
||||
|
||||
pid, pidErr := readLockPID(f)
|
||||
_ = f.Close()
|
||||
|
||||
if pidErr != nil {
|
||||
return nil, fmt.Errorf("another instance is already running")
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("another instance is already running (PID %d)", pid)
|
||||
}
|
||||
|
||||
func ReleaseLock(f *os.File) {
|
||||
_ = syscall.Flock(int(f.Fd()), syscall.LOCK_UN)
|
||||
|
||||
lockedStat, lockedErr := f.Stat()
|
||||
_ = f.Close()
|
||||
|
||||
if lockedErr != nil {
|
||||
_ = os.Remove(lockFileName)
|
||||
return
|
||||
}
|
||||
|
||||
diskStat, diskErr := os.Stat(lockFileName)
|
||||
if diskErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if os.SameFile(lockedStat, diskStat) {
|
||||
_ = os.Remove(lockFileName)
|
||||
}
|
||||
}
|
||||
|
||||
func ReadLockFilePID() (int, error) {
|
||||
f, err := os.Open(lockFileName)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
return readLockPID(f)
|
||||
}
|
||||
|
||||
func writePID(f *os.File) error {
|
||||
if err := f.Truncate(0); err != nil {
|
||||
return fmt.Errorf("failed to truncate lock file: %w", err)
|
||||
}
|
||||
|
||||
if _, err := f.Seek(0, io.SeekStart); err != nil {
|
||||
return fmt.Errorf("failed to seek lock file: %w", err)
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(f, "%d\n", os.Getpid()); err != nil {
|
||||
return fmt.Errorf("failed to write PID to lock file: %w", err)
|
||||
}
|
||||
|
||||
return f.Sync()
|
||||
}
|
||||
|
||||
func readLockPID(f *os.File) (int, error) {
|
||||
if _, err := f.Seek(0, io.SeekStart); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
s := strings.TrimSpace(string(data))
|
||||
if s == "" {
|
||||
return 0, errors.New("lock file is empty")
|
||||
}
|
||||
|
||||
pid, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid PID in lock file: %w", err)
|
||||
}
|
||||
|
||||
return pid, nil
|
||||
}
|
||||
|
||||
func isProcessAlive(pid int) bool {
|
||||
err := syscall.Kill(pid, 0)
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if errors.Is(err, syscall.EPERM) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
148
agent/internal/features/start/lock_test.go
Normal file
148
agent/internal/features/start/lock_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
//go:build !windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
func Test_AcquireLock_LockFileCreatedWithPID(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
data, err := os.ReadFile(lockFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, os.Getpid(), pid)
|
||||
}
|
||||
|
||||
func Test_AcquireLock_SecondAcquireFails_WhenFirstHeld(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
first, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(first)
|
||||
|
||||
second, err := AcquireLock(log)
|
||||
assert.Nil(t, second)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "another instance is already running")
|
||||
assert.Contains(t, err.Error(), fmt.Sprintf("PID %d", os.Getpid()))
|
||||
}
|
||||
|
||||
func Test_AcquireLock_StaleLockReacquired_WhenProcessDead(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
err := os.WriteFile(lockFileName, []byte("999999999\n"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
data, err := os.ReadFile(lockFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, os.Getpid(), pid)
|
||||
}
|
||||
|
||||
func Test_ReleaseLock_LockFileRemoved(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
|
||||
ReleaseLock(lockFile)
|
||||
|
||||
_, err = os.Stat(lockFileName)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
}
|
||||
|
||||
func Test_AcquireLock_ReacquiredAfterRelease(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
first, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
ReleaseLock(first)
|
||||
|
||||
second, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(second)
|
||||
|
||||
data, err := os.ReadFile(lockFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, os.Getpid(), pid)
|
||||
}
|
||||
|
||||
func Test_isProcessAlive_ReturnsTrueForSelf(t *testing.T) {
|
||||
assert.True(t, isProcessAlive(os.Getpid()))
|
||||
}
|
||||
|
||||
func Test_isProcessAlive_ReturnsFalseForNonExistentPID(t *testing.T) {
|
||||
assert.False(t, isProcessAlive(999999999))
|
||||
}
|
||||
|
||||
func Test_readLockPID_ParsesValidPID(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
f, err := os.CreateTemp("", "lock-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(f.Name())
|
||||
|
||||
_, err = f.WriteString("12345\n")
|
||||
require.NoError(t, err)
|
||||
|
||||
pid, err := readLockPID(f)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 12345, pid)
|
||||
}
|
||||
|
||||
func Test_readLockPID_ReturnsErrorForEmptyFile(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
f, err := os.CreateTemp("", "lock-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(f.Name())
|
||||
|
||||
_, err = readLockPID(f)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "lock file is empty")
|
||||
}
|
||||
|
||||
func setupTempDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
origDir, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
|
||||
dir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(dir))
|
||||
|
||||
t.Cleanup(func() { _ = os.Chdir(origDir) })
|
||||
|
||||
return dir
|
||||
}
|
||||
90
agent/internal/features/start/lock_watcher.go
Normal file
90
agent/internal/features/start/lock_watcher.go
Normal file
@@ -0,0 +1,90 @@
|
||||
//go:build !windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
const lockWatchInterval = 5 * time.Second
|
||||
|
||||
type LockWatcher struct {
|
||||
originalInode uint64
|
||||
cancel context.CancelFunc
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
func NewLockWatcher(lockFile *os.File, cancel context.CancelFunc, log *slog.Logger) (*LockWatcher, error) {
|
||||
inode, err := getFileInode(lockFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &LockWatcher{
|
||||
originalInode: inode,
|
||||
cancel: cancel,
|
||||
log: log,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *LockWatcher) Run(ctx context.Context) {
|
||||
ticker := time.NewTicker(lockWatchInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
w.check()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *LockWatcher) check() {
|
||||
info, err := os.Stat(lockFileName)
|
||||
if err != nil {
|
||||
w.log.Error("Lock file disappeared, shutting down", "file", lockFileName, "error", err)
|
||||
w.cancel()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
currentInode, err := getStatInode(info)
|
||||
if err != nil {
|
||||
w.log.Error("Failed to read lock file inode, shutting down", "error", err)
|
||||
w.cancel()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if currentInode != w.originalInode {
|
||||
w.log.Error("Lock file was replaced (inode changed), shutting down",
|
||||
"originalInode", w.originalInode,
|
||||
"currentInode", currentInode,
|
||||
)
|
||||
w.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
func getFileInode(f *os.File) (uint64, error) {
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return getStatInode(info)
|
||||
}
|
||||
|
||||
func getStatInode(info os.FileInfo) (uint64, error) {
|
||||
stat, ok := info.Sys().(*syscall.Stat_t)
|
||||
if !ok {
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
|
||||
return stat.Ino, nil
|
||||
}
|
||||
110
agent/internal/features/start/lock_watcher_test.go
Normal file
110
agent/internal/features/start/lock_watcher_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
//go:build !windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
func Test_NewLockWatcher_CapturesInode(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
_, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, watcher.originalInode)
|
||||
}
|
||||
|
||||
func Test_LockWatcher_FileUnchanged_ContextNotCancelled(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
watcher.check()
|
||||
watcher.check()
|
||||
watcher.check()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("context should not be cancelled when lock file is unchanged")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func Test_LockWatcher_FileDeleted_CancelsContext(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.Remove(lockFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
watcher.check()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
default:
|
||||
t.Fatal("context should be cancelled when lock file is deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_LockWatcher_FileReplacedWithDifferentInode_CancelsContext(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
log := logger.GetLogger()
|
||||
|
||||
lockFile, err := AcquireLock(log)
|
||||
require.NoError(t, err)
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.Remove(lockFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(lockFileName, []byte("99999\n"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
watcher.check()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
default:
|
||||
t.Fatal("context should be cancelled when lock file inode changes")
|
||||
}
|
||||
}
|
||||
17
agent/internal/features/start/lock_watcher_windows.go
Normal file
17
agent/internal/features/start/lock_watcher_windows.go
Normal file
@@ -0,0 +1,17 @@
|
||||
//go:build windows
|
||||
|
||||
package start
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
)
|
||||
|
||||
type LockWatcher struct{}
|
||||
|
||||
func NewLockWatcher(_ *os.File, _ context.CancelFunc, _ *slog.Logger) (*LockWatcher, error) {
|
||||
return &LockWatcher{}, nil
|
||||
}
|
||||
|
||||
func (w *LockWatcher) Run(_ context.Context) {}
|
||||
18
agent/internal/features/start/lock_windows.go
Normal file
18
agent/internal/features/start/lock_windows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package start
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
)
|
||||
|
||||
func AcquireLock(log *slog.Logger) (*os.File, error) {
|
||||
log.Warn("Process locking is not supported on Windows, skipping")
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func ReleaseLock(f *os.File) {
|
||||
if f != nil {
|
||||
_ = f.Close()
|
||||
}
|
||||
}
|
||||
325
agent/internal/features/start/start.go
Normal file
325
agent/internal/features/start/start.go
Normal file
@@ -0,0 +1,325 @@
|
||||
package start
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
full_backup "databasus-agent/internal/features/full_backup"
|
||||
"databasus-agent/internal/features/upgrade"
|
||||
"databasus-agent/internal/features/wal"
|
||||
)
|
||||
|
||||
const (
|
||||
pgBasebackupVerifyTimeout = 10 * time.Second
|
||||
dbVerifyTimeout = 10 * time.Second
|
||||
minPgMajorVersion = 15
|
||||
)
|
||||
|
||||
func Start(cfg *config.Config, agentVersion string, isDev bool, log *slog.Logger) error {
|
||||
if err := validateConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := verifyPgBasebackup(cfg, log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := verifyDatabase(cfg, log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
return RunDaemon(cfg, agentVersion, isDev, log)
|
||||
}
|
||||
|
||||
pid, err := spawnDaemon(log)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Agent started in background (PID %d)\n", pid)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunDaemon(cfg *config.Config, agentVersion string, isDev bool, log *slog.Logger) error {
|
||||
lockFile, err := AcquireLock(log)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer ReleaseLock(lockFile)
|
||||
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
|
||||
watcher, err := NewLockWatcher(lockFile, cancel, log)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize lock watcher: %w", err)
|
||||
}
|
||||
go watcher.Run(ctx)
|
||||
|
||||
apiClient := api.NewClient(cfg.DatabasusHost, cfg.Token, log)
|
||||
|
||||
var backgroundUpgrader *upgrade.BackgroundUpgrader
|
||||
if agentVersion != "dev" && runtime.GOOS != "windows" {
|
||||
backgroundUpgrader = upgrade.NewBackgroundUpgrader(apiClient, agentVersion, isDev, cancel, log)
|
||||
go backgroundUpgrader.Run(ctx)
|
||||
}
|
||||
|
||||
fullBackuper := full_backup.NewFullBackuper(cfg, apiClient, log)
|
||||
go fullBackuper.Run(ctx)
|
||||
|
||||
streamer := wal.NewStreamer(cfg, apiClient, log)
|
||||
streamer.Run(ctx)
|
||||
|
||||
if backgroundUpgrader != nil {
|
||||
backgroundUpgrader.WaitForCompletion(30 * time.Second)
|
||||
|
||||
if backgroundUpgrader.IsUpgraded() {
|
||||
return upgrade.ErrUpgradeRestart
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("Agent stopped")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateConfig(cfg *config.Config) error {
|
||||
if cfg.DatabasusHost == "" {
|
||||
return errors.New("argument databasus-host is required")
|
||||
}
|
||||
|
||||
if cfg.DbID == "" {
|
||||
return errors.New("argument db-id is required")
|
||||
}
|
||||
|
||||
if cfg.Token == "" {
|
||||
return errors.New("argument token is required")
|
||||
}
|
||||
|
||||
if cfg.PgHost == "" {
|
||||
return errors.New("argument pg-host is required")
|
||||
}
|
||||
|
||||
if cfg.PgPort <= 0 {
|
||||
return errors.New("argument pg-port must be a positive number")
|
||||
}
|
||||
|
||||
if cfg.PgUser == "" {
|
||||
return errors.New("argument pg-user is required")
|
||||
}
|
||||
|
||||
if cfg.PgType != "host" && cfg.PgType != "docker" {
|
||||
return fmt.Errorf("argument pg-type must be 'host' or 'docker', got '%s'", cfg.PgType)
|
||||
}
|
||||
|
||||
if cfg.PgWalDir == "" {
|
||||
return errors.New("argument pg-wal-dir is required")
|
||||
}
|
||||
|
||||
if cfg.PgType == "docker" && cfg.PgDockerContainerName == "" {
|
||||
return errors.New("argument pg-docker-container-name is required when pg-type is 'docker'")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyPgBasebackup(cfg *config.Config, log *slog.Logger) error {
|
||||
switch cfg.PgType {
|
||||
case "host":
|
||||
return verifyPgBasebackupHost(cfg, log)
|
||||
case "docker":
|
||||
return verifyPgBasebackupDocker(cfg, log)
|
||||
default:
|
||||
return fmt.Errorf("unexpected pg-type: %s", cfg.PgType)
|
||||
}
|
||||
}
|
||||
|
||||
func verifyPgBasebackupHost(cfg *config.Config, log *slog.Logger) error {
|
||||
binary := "pg_basebackup"
|
||||
if cfg.PgHostBinDir != "" {
|
||||
binary = filepath.Join(cfg.PgHostBinDir, "pg_basebackup")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgBasebackupVerifyTimeout)
|
||||
defer cancel()
|
||||
|
||||
output, err := exec.CommandContext(ctx, binary, "--version").CombinedOutput()
|
||||
if err != nil {
|
||||
if cfg.PgHostBinDir != "" {
|
||||
return fmt.Errorf(
|
||||
"pg_basebackup not found at '%s': %w. Verify pg-host-bin-dir is correct",
|
||||
binary, err,
|
||||
)
|
||||
}
|
||||
|
||||
return fmt.Errorf(
|
||||
"pg_basebackup not found in PATH: %w. Install PostgreSQL client tools or set pg-host-bin-dir",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
log.Info("pg_basebackup verified", "version", strings.TrimSpace(string(output)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyPgBasebackupDocker(cfg *config.Config, log *slog.Logger) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgBasebackupVerifyTimeout)
|
||||
defer cancel()
|
||||
|
||||
output, err := exec.CommandContext(ctx,
|
||||
"docker", "exec", cfg.PgDockerContainerName,
|
||||
"pg_basebackup", "--version",
|
||||
).CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"pg_basebackup not available in container '%s': %w. "+
|
||||
"Check that the container is running and pg_basebackup is installed inside it",
|
||||
cfg.PgDockerContainerName, err,
|
||||
)
|
||||
}
|
||||
|
||||
log.Info("pg_basebackup verified (docker)",
|
||||
"container", cfg.PgDockerContainerName,
|
||||
"version", strings.TrimSpace(string(output)),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyDatabase(cfg *config.Config, log *slog.Logger) error {
|
||||
switch cfg.PgType {
|
||||
case "docker":
|
||||
return verifyDatabaseDocker(cfg, log)
|
||||
default:
|
||||
return verifyDatabaseHost(cfg, log)
|
||||
}
|
||||
}
|
||||
|
||||
func verifyDatabaseHost(cfg *config.Config, log *slog.Logger) error {
|
||||
connStr := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=postgres sslmode=disable",
|
||||
cfg.PgHost, cfg.PgPort, cfg.PgUser, cfg.PgPassword,
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dbVerifyTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgx.Connect(ctx, connStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to connect to PostgreSQL at %s:%d as user '%s': %w",
|
||||
cfg.PgHost, cfg.PgPort, cfg.PgUser, err,
|
||||
)
|
||||
}
|
||||
defer func() { _ = conn.Close(ctx) }()
|
||||
|
||||
if err := conn.Ping(ctx); err != nil {
|
||||
return fmt.Errorf("PostgreSQL ping failed at %s:%d: %w",
|
||||
cfg.PgHost, cfg.PgPort, err,
|
||||
)
|
||||
}
|
||||
|
||||
var versionNumStr string
|
||||
if err := conn.QueryRow(ctx, "SHOW server_version_num").Scan(&versionNumStr); err != nil {
|
||||
return fmt.Errorf("failed to query PostgreSQL version: %w", err)
|
||||
}
|
||||
|
||||
majorVersion, err := parsePgVersionNum(versionNumStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse PostgreSQL version '%s': %w", versionNumStr, err)
|
||||
}
|
||||
|
||||
if majorVersion < minPgMajorVersion {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL %d is not supported, minimum required version is %d",
|
||||
majorVersion, minPgMajorVersion,
|
||||
)
|
||||
}
|
||||
|
||||
log.Info("PostgreSQL connection verified",
|
||||
"host", cfg.PgHost,
|
||||
"port", cfg.PgPort,
|
||||
"user", cfg.PgUser,
|
||||
"version", majorVersion,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyDatabaseDocker(cfg *config.Config, log *slog.Logger) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dbVerifyTimeout)
|
||||
defer cancel()
|
||||
|
||||
query := "SELECT current_setting('server_version_num')"
|
||||
|
||||
cmd := exec.CommandContext(ctx,
|
||||
"docker", "exec",
|
||||
"-e", "PGPASSWORD="+cfg.PgPassword,
|
||||
cfg.PgDockerContainerName,
|
||||
"psql", "-h", "localhost", "-p", "5432", "-U", cfg.PgUser,
|
||||
"-d", "postgres", "-t", "-A", "-c", query,
|
||||
)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to connect to PostgreSQL in container '%s' as user '%s': %w (output: %s)",
|
||||
cfg.PgDockerContainerName, cfg.PgUser, err, strings.TrimSpace(string(output)),
|
||||
)
|
||||
}
|
||||
|
||||
versionNumStr := strings.TrimSpace(string(output))
|
||||
|
||||
majorVersion, err := parsePgVersionNum(versionNumStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse PostgreSQL version '%s': %w", versionNumStr, err)
|
||||
}
|
||||
|
||||
if majorVersion < minPgMajorVersion {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL %d is not supported, minimum required version is %d",
|
||||
majorVersion, minPgMajorVersion,
|
||||
)
|
||||
}
|
||||
|
||||
log.Info("PostgreSQL connection verified (docker)",
|
||||
"container", cfg.PgDockerContainerName,
|
||||
"user", cfg.PgUser,
|
||||
"version", majorVersion,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parsePgVersionNum(versionNumStr string) (int, error) {
|
||||
versionNum, err := strconv.Atoi(strings.TrimSpace(versionNumStr))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid version number: %w", err)
|
||||
}
|
||||
|
||||
if versionNum <= 0 {
|
||||
return 0, fmt.Errorf("invalid version number: %d", versionNum)
|
||||
}
|
||||
|
||||
majorVersion := versionNum / 10000
|
||||
|
||||
return majorVersion, nil
|
||||
}
|
||||
84
agent/internal/features/start/start_test.go
Normal file
84
agent/internal/features/start/start_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package start
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_ParsePgVersionNum_SupportedVersions_ReturnsMajorVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
versionNumStr string
|
||||
expectedMajor int
|
||||
}{
|
||||
{name: "PG 15.0", versionNumStr: "150000", expectedMajor: 15},
|
||||
{name: "PG 15.4", versionNumStr: "150004", expectedMajor: 15},
|
||||
{name: "PG 16.0", versionNumStr: "160000", expectedMajor: 16},
|
||||
{name: "PG 16.3", versionNumStr: "160003", expectedMajor: 16},
|
||||
{name: "PG 17.2", versionNumStr: "170002", expectedMajor: 17},
|
||||
{name: "PG 18.0", versionNumStr: "180000", expectedMajor: 18},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
major, err := parsePgVersionNum(tt.versionNumStr)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedMajor, major)
|
||||
assert.GreaterOrEqual(t, major, minPgMajorVersion)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ParsePgVersionNum_UnsupportedVersions_ReturnsMajorVersionBelow15(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
versionNumStr string
|
||||
expectedMajor int
|
||||
}{
|
||||
{name: "PG 12.5", versionNumStr: "120005", expectedMajor: 12},
|
||||
{name: "PG 13.0", versionNumStr: "130000", expectedMajor: 13},
|
||||
{name: "PG 14.12", versionNumStr: "140012", expectedMajor: 14},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
major, err := parsePgVersionNum(tt.versionNumStr)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedMajor, major)
|
||||
assert.Less(t, major, minPgMajorVersion)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ParsePgVersionNum_InvalidInput_ReturnsError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
versionNumStr string
|
||||
}{
|
||||
{name: "empty string", versionNumStr: ""},
|
||||
{name: "non-numeric", versionNumStr: "abc"},
|
||||
{name: "negative number", versionNumStr: "-1"},
|
||||
{name: "zero", versionNumStr: "0"},
|
||||
{name: "float", versionNumStr: "15.4"},
|
||||
{name: "whitespace only", versionNumStr: " "},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := parsePgVersionNum(tt.versionNumStr)
|
||||
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ParsePgVersionNum_WithWhitespace_ParsesCorrectly(t *testing.T) {
|
||||
major, err := parsePgVersionNum(" 150004 ")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 15, major)
|
||||
}
|
||||
88
agent/internal/features/upgrade/background_upgrader.go
Normal file
88
agent/internal/features/upgrade/background_upgrader.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package upgrade
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"databasus-agent/internal/features/api"
|
||||
)
|
||||
|
||||
const backgroundCheckInterval = 10 * time.Second
|
||||
|
||||
type BackgroundUpgrader struct {
|
||||
apiClient *api.Client
|
||||
currentVersion string
|
||||
isDev bool
|
||||
cancel context.CancelFunc
|
||||
isUpgraded atomic.Bool
|
||||
log *slog.Logger
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func NewBackgroundUpgrader(
|
||||
apiClient *api.Client,
|
||||
currentVersion string,
|
||||
isDev bool,
|
||||
cancel context.CancelFunc,
|
||||
log *slog.Logger,
|
||||
) *BackgroundUpgrader {
|
||||
return &BackgroundUpgrader{
|
||||
apiClient,
|
||||
currentVersion,
|
||||
isDev,
|
||||
cancel,
|
||||
atomic.Bool{},
|
||||
log,
|
||||
make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (u *BackgroundUpgrader) Run(ctx context.Context) {
|
||||
defer close(u.done)
|
||||
|
||||
ticker := time.NewTicker(backgroundCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if u.checkAndUpgrade() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *BackgroundUpgrader) IsUpgraded() bool {
|
||||
return u.isUpgraded.Load()
|
||||
}
|
||||
|
||||
func (u *BackgroundUpgrader) WaitForCompletion(timeout time.Duration) {
|
||||
select {
|
||||
case <-u.done:
|
||||
case <-time.After(timeout):
|
||||
}
|
||||
}
|
||||
|
||||
func (u *BackgroundUpgrader) checkAndUpgrade() bool {
|
||||
isUpgraded, err := CheckAndUpdate(u.apiClient, u.currentVersion, u.isDev, u.log)
|
||||
if err != nil {
|
||||
u.log.Warn("Background update check failed", "error", err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
if !isUpgraded {
|
||||
return false
|
||||
}
|
||||
|
||||
u.log.Info("Background upgrade complete, restarting...")
|
||||
u.isUpgraded.Store(true)
|
||||
u.cancel()
|
||||
|
||||
return true
|
||||
}
|
||||
5
agent/internal/features/upgrade/errors.go
Normal file
5
agent/internal/features/upgrade/errors.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package upgrade
|
||||
|
||||
import "errors"
|
||||
|
||||
var ErrUpgradeRestart = errors.New("agent upgraded, restart required")
|
||||
89
agent/internal/features/upgrade/upgrader.go
Normal file
89
agent/internal/features/upgrade/upgrader.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package upgrade
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"databasus-agent/internal/features/api"
|
||||
)
|
||||
|
||||
// CheckAndUpdate checks if a new version is available and upgrades the binary on disk.
|
||||
// Returns (true, nil) if the binary was upgraded, (false, nil) if already up to date,
|
||||
// or (false, err) on failure. Callers are responsible for re-exec or restart signaling.
|
||||
func CheckAndUpdate(apiClient *api.Client, currentVersion string, isDev bool, log *slog.Logger) (bool, error) {
|
||||
if isDev {
|
||||
log.Info("Skipping update check (development mode)")
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
serverVersion, err := apiClient.FetchServerVersion(context.Background())
|
||||
if err != nil {
|
||||
log.Warn("Could not reach server for update check", "error", err)
|
||||
|
||||
return false, fmt.Errorf(
|
||||
"unable to check version, please verify Databasus server is available: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
if serverVersion == currentVersion {
|
||||
log.Info("Agent version is up to date", "version", currentVersion)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
log.Info("Updating agent...", "current", currentVersion, "target", serverVersion)
|
||||
|
||||
selfPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to determine executable path: %w", err)
|
||||
}
|
||||
|
||||
tempPath := selfPath + ".update"
|
||||
|
||||
defer func() {
|
||||
_ = os.Remove(tempPath)
|
||||
}()
|
||||
|
||||
if err := apiClient.DownloadAgentBinary(context.Background(), runtime.GOARCH, tempPath); err != nil {
|
||||
return false, fmt.Errorf("failed to download update: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tempPath, 0o755); err != nil {
|
||||
return false, fmt.Errorf("failed to set permissions on update: %w", err)
|
||||
}
|
||||
|
||||
if err := verifyBinary(tempPath, serverVersion); err != nil {
|
||||
return false, fmt.Errorf("update verification failed: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tempPath, selfPath); err != nil {
|
||||
return false, fmt.Errorf("failed to replace binary (try --skip-update if this persists): %w", err)
|
||||
}
|
||||
|
||||
log.Info("Agent binary updated", "version", serverVersion)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func verifyBinary(binaryPath, expectedVersion string) error {
|
||||
cmd := exec.CommandContext(context.Background(), binaryPath, "version")
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("binary failed to execute: %w", err)
|
||||
}
|
||||
|
||||
got := strings.TrimSpace(string(output))
|
||||
if got != expectedVersion {
|
||||
return fmt.Errorf("version mismatch: expected %q, got %q", expectedVersion, got)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
204
agent/internal/features/wal/streamer.go
Normal file
204
agent/internal/features/wal/streamer.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
)
|
||||
|
||||
var uploadIdleTimeout = 5 * time.Minute
|
||||
|
||||
const (
|
||||
pollInterval = 10 * time.Second
|
||||
uploadTimeout = 5 * time.Minute
|
||||
)
|
||||
|
||||
var segmentNameRegex = regexp.MustCompile(`^[0-9A-Fa-f]{24}$`)
|
||||
|
||||
type Streamer struct {
|
||||
cfg *config.Config
|
||||
apiClient *api.Client
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
func NewStreamer(cfg *config.Config, apiClient *api.Client, log *slog.Logger) *Streamer {
|
||||
return &Streamer{
|
||||
cfg: cfg,
|
||||
apiClient: apiClient,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Streamer) Run(ctx context.Context) {
|
||||
s.log.Info("WAL streamer started", "pgWalDir", s.cfg.PgWalDir)
|
||||
|
||||
s.processQueue(ctx)
|
||||
|
||||
ticker := time.NewTicker(pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.log.Info("WAL streamer stopping")
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.processQueue(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Streamer) processQueue(ctx context.Context) {
|
||||
segments, err := s.listSegments()
|
||||
if err != nil {
|
||||
s.log.Error("Failed to list WAL segments", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(segments) == 0 {
|
||||
s.log.Info("No WAL segments pending", "dir", s.cfg.PgWalDir)
|
||||
return
|
||||
}
|
||||
|
||||
s.log.Info("WAL segments pending upload", "dir", s.cfg.PgWalDir, "count", len(segments))
|
||||
|
||||
for _, segmentName := range segments {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.uploadSegment(ctx, segmentName); err != nil {
|
||||
s.log.Error("Failed to upload WAL segment",
|
||||
"segment", segmentName,
|
||||
"error", err,
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Streamer) listSegments() ([]string, error) {
|
||||
entries, err := os.ReadDir(s.cfg.PgWalDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read wal dir: %w", err)
|
||||
}
|
||||
|
||||
var segments []string
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
name := entry.Name()
|
||||
|
||||
if strings.HasSuffix(name, ".tmp") {
|
||||
continue
|
||||
}
|
||||
|
||||
if !segmentNameRegex.MatchString(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
segments = append(segments, name)
|
||||
}
|
||||
|
||||
slices.Sort(segments)
|
||||
|
||||
return segments, nil
|
||||
}
|
||||
|
||||
func (s *Streamer) uploadSegment(ctx context.Context, segmentName string) error {
|
||||
filePath := filepath.Join(s.cfg.PgWalDir, segmentName)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
defer func() { _ = pr.Close() }()
|
||||
|
||||
go s.compressAndStream(pw, filePath)
|
||||
|
||||
uploadCtx, timeoutCancel := context.WithTimeout(ctx, uploadTimeout)
|
||||
defer timeoutCancel()
|
||||
|
||||
idleCtx, idleCancel := context.WithCancelCause(uploadCtx)
|
||||
defer idleCancel(nil)
|
||||
|
||||
idleReader := api.NewIdleTimeoutReader(pr, uploadIdleTimeout, idleCancel)
|
||||
defer idleReader.Stop()
|
||||
|
||||
s.log.Info("Uploading WAL segment", "segment", segmentName)
|
||||
|
||||
result, err := s.apiClient.UploadWalSegment(idleCtx, segmentName, idleReader)
|
||||
if err != nil {
|
||||
if cause := context.Cause(idleCtx); cause != nil {
|
||||
return fmt.Errorf("upload WAL segment: %w", cause)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
if result.IsGapDetected {
|
||||
s.log.Warn("WAL chain gap detected",
|
||||
"segment", segmentName,
|
||||
"expected", result.ExpectedSegmentName,
|
||||
"received", result.ReceivedSegmentName,
|
||||
)
|
||||
|
||||
return fmt.Errorf("gap detected for segment %s", segmentName)
|
||||
}
|
||||
|
||||
s.log.Info("WAL segment uploaded", "segment", segmentName)
|
||||
|
||||
if *s.cfg.IsDeleteWalAfterUpload {
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
s.log.Warn("Failed to delete uploaded WAL segment",
|
||||
"segment", segmentName,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Streamer) compressAndStream(pw *io.PipeWriter, filePath string) {
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
_ = pw.CloseWithError(fmt.Errorf("open file: %w", err))
|
||||
return
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
encoder, err := zstd.NewWriter(pw,
|
||||
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(5)),
|
||||
zstd.WithEncoderCRC(true),
|
||||
)
|
||||
if err != nil {
|
||||
_ = pw.CloseWithError(fmt.Errorf("create zstd encoder: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := io.Copy(encoder, f); err != nil {
|
||||
_ = encoder.Close()
|
||||
_ = pw.CloseWithError(fmt.Errorf("compress: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := encoder.Close(); err != nil {
|
||||
_ = pw.CloseWithError(fmt.Errorf("close encoder: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
_ = pw.Close()
|
||||
}
|
||||
393
agent/internal/features/wal/streamer_test.go
Normal file
393
agent/internal/features/wal/streamer_test.go
Normal file
@@ -0,0 +1,393 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/api"
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
func Test_UploadSegment_SingleSegment_ServerReceivesCorrectHeadersAndBody(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
segmentContent := []byte("test-wal-segment-data-for-upload")
|
||||
writeTestSegment(t, walDir, "000000010000000100000001", segmentContent)
|
||||
|
||||
var receivedHeaders http.Header
|
||||
var receivedBody []byte
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedHeaders = r.Header.Clone()
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
receivedBody = body
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
require.NotNil(t, receivedHeaders)
|
||||
assert.Equal(t, "test-token", receivedHeaders.Get("Authorization"))
|
||||
assert.Equal(t, "application/octet-stream", receivedHeaders.Get("Content-Type"))
|
||||
assert.Equal(t, "000000010000000100000001", receivedHeaders.Get("X-Wal-Segment-Name"))
|
||||
|
||||
decompressed := decompressZstd(t, receivedBody)
|
||||
assert.Equal(t, segmentContent, decompressed)
|
||||
}
|
||||
|
||||
func Test_UploadSegments_MultipleSegmentsOutOfOrder_UploadedInAscendingOrder(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
writeTestSegment(t, walDir, "000000010000000100000003", []byte("third"))
|
||||
writeTestSegment(t, walDir, "000000010000000100000001", []byte("first"))
|
||||
writeTestSegment(t, walDir, "000000010000000100000002", []byte("second"))
|
||||
|
||||
var mu sync.Mutex
|
||||
var uploadOrder []string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
uploadOrder = append(uploadOrder, r.Header.Get("X-Wal-Segment-Name"))
|
||||
mu.Unlock()
|
||||
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
require.Len(t, uploadOrder, 3)
|
||||
assert.Equal(t, "000000010000000100000001", uploadOrder[0])
|
||||
assert.Equal(t, "000000010000000100000002", uploadOrder[1])
|
||||
assert.Equal(t, "000000010000000100000003", uploadOrder[2])
|
||||
}
|
||||
|
||||
func Test_UploadSegments_DirectoryHasTmpFiles_TmpFilesIgnored(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
writeTestSegment(t, walDir, "000000010000000100000001", []byte("real segment"))
|
||||
writeTestSegment(t, walDir, "000000010000000100000002.tmp", []byte("partial copy"))
|
||||
|
||||
var mu sync.Mutex
|
||||
var uploadedSegments []string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
uploadedSegments = append(uploadedSegments, r.Header.Get("X-Wal-Segment-Name"))
|
||||
mu.Unlock()
|
||||
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
require.Len(t, uploadedSegments, 1)
|
||||
assert.Equal(t, "000000010000000100000001", uploadedSegments[0])
|
||||
}
|
||||
|
||||
func Test_UploadSegment_DeleteEnabled_FileRemovedAfterUpload(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
segmentName := "000000010000000100000001"
|
||||
writeTestSegment(t, walDir, segmentName, []byte("segment data"))
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
isDeleteEnabled := true
|
||||
cfg := createTestConfig(walDir, server.URL)
|
||||
cfg.IsDeleteWalAfterUpload = &isDeleteEnabled
|
||||
apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger())
|
||||
streamer := NewStreamer(cfg, apiClient, logger.GetLogger())
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
_, err := os.Stat(filepath.Join(walDir, segmentName))
|
||||
assert.True(t, os.IsNotExist(err), "segment file should be deleted after successful upload")
|
||||
}
|
||||
|
||||
func Test_UploadSegment_DeleteDisabled_FileKeptAfterUpload(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
segmentName := "000000010000000100000001"
|
||||
writeTestSegment(t, walDir, segmentName, []byte("segment data"))
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
isDeleteDisabled := false
|
||||
cfg := createTestConfig(walDir, server.URL)
|
||||
cfg.IsDeleteWalAfterUpload = &isDeleteDisabled
|
||||
apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger())
|
||||
streamer := NewStreamer(cfg, apiClient, logger.GetLogger())
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
_, err := os.Stat(filepath.Join(walDir, segmentName))
|
||||
assert.NoError(t, err, "segment file should be kept when delete is disabled")
|
||||
}
|
||||
|
||||
func Test_UploadSegment_ServerReturns500_FileKeptInQueue(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
segmentName := "000000010000000100000001"
|
||||
writeTestSegment(t, walDir, segmentName, []byte("segment data"))
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(`{"error":"internal server error"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
_, err := os.Stat(filepath.Join(walDir, segmentName))
|
||||
assert.NoError(t, err, "segment file should remain in queue after server error")
|
||||
}
|
||||
|
||||
func Test_ProcessQueue_EmptyDirectory_NoUploads(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
|
||||
uploadCount := 0
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
uploadCount++
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
assert.Equal(t, 0, uploadCount, "no uploads should occur for empty directory")
|
||||
}
|
||||
|
||||
func Test_Run_ContextCancelled_StopsImmediately(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
|
||||
streamer := newTestStreamer(walDir, "http://localhost:0")
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
streamer.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Run should have stopped immediately when context is already cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_UploadSegment_ServerReturns409_FileNotDeleted(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
segmentName := "000000010000000100000005"
|
||||
writeTestSegment(t, walDir, segmentName, []byte("gap segment"))
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
|
||||
resp := map[string]string{
|
||||
"error": "gap_detected",
|
||||
"expectedSegmentName": "000000010000000100000003",
|
||||
"receivedSegmentName": segmentName,
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go streamer.Run(ctx)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
_, err := os.Stat(filepath.Join(walDir, segmentName))
|
||||
assert.NoError(t, err, "segment file should not be deleted on gap detection")
|
||||
}
|
||||
|
||||
func Test_UploadSegment_WhenUploadStalls_FailsWithIdleTimeout(t *testing.T) {
|
||||
walDir := createTestWalDir(t)
|
||||
|
||||
// Use incompressible random data to ensure TCP buffers fill up
|
||||
segmentContent := make([]byte, 1024*1024)
|
||||
_, err := rand.Read(segmentContent)
|
||||
require.NoError(t, err)
|
||||
|
||||
writeTestSegment(t, walDir, "000000010000000100000001", segmentContent)
|
||||
|
||||
var requestReceived atomic.Bool
|
||||
handlerDone := make(chan struct{})
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestReceived.Store(true)
|
||||
|
||||
// Read one byte then stall — simulates a network stall
|
||||
buf := make([]byte, 1)
|
||||
_, _ = r.Body.Read(buf)
|
||||
<-handlerDone
|
||||
}))
|
||||
defer server.Close()
|
||||
defer close(handlerDone)
|
||||
|
||||
origIdleTimeout := uploadIdleTimeout
|
||||
uploadIdleTimeout = 200 * time.Millisecond
|
||||
defer func() { uploadIdleTimeout = origIdleTimeout }()
|
||||
|
||||
streamer := newTestStreamer(walDir, server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
uploadErr := streamer.uploadSegment(ctx, "000000010000000100000001")
|
||||
|
||||
assert.Error(t, uploadErr, "upload should fail when stalled")
|
||||
assert.True(t, requestReceived.Load(), "server should have received the request")
|
||||
assert.Contains(t, uploadErr.Error(), "idle timeout", "error should mention idle timeout")
|
||||
|
||||
_, statErr := os.Stat(filepath.Join(walDir, "000000010000000100000001"))
|
||||
assert.NoError(t, statErr, "segment file should remain in queue after idle timeout")
|
||||
}
|
||||
|
||||
func newTestStreamer(walDir, serverURL string) *Streamer {
|
||||
cfg := createTestConfig(walDir, serverURL)
|
||||
apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger())
|
||||
|
||||
return NewStreamer(cfg, apiClient, logger.GetLogger())
|
||||
}
|
||||
|
||||
func createTestWalDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
baseDir := filepath.Join(".", ".test-tmp")
|
||||
if err := os.MkdirAll(baseDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create base test dir: %v", err)
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp(baseDir, t.Name()+"-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test wal dir: %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(dir)
|
||||
})
|
||||
|
||||
return dir
|
||||
}
|
||||
|
||||
func writeTestSegment(t *testing.T, dir, name string, content []byte) {
|
||||
t.Helper()
|
||||
|
||||
if err := os.WriteFile(filepath.Join(dir, name), content, 0o644); err != nil {
|
||||
t.Fatalf("failed to write test segment %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
func createTestConfig(walDir, serverURL string) *config.Config {
|
||||
isDeleteEnabled := true
|
||||
|
||||
return &config.Config{
|
||||
DatabasusHost: serverURL,
|
||||
DbID: "test-db-id",
|
||||
Token: "test-token",
|
||||
PgWalDir: walDir,
|
||||
IsDeleteWalAfterUpload: &isDeleteEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func decompressZstd(t *testing.T, data []byte) []byte {
|
||||
t.Helper()
|
||||
|
||||
decoder, err := zstd.NewReader(nil)
|
||||
require.NoError(t, err)
|
||||
defer decoder.Close()
|
||||
|
||||
decoded, err := decoder.DecodeAll(data, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
return decoded
|
||||
}
|
||||
115
agent/internal/logger/logger.go
Normal file
115
agent/internal/logger/logger.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
logFileName = "databasus.log"
|
||||
oldLogFileName = "databasus.log.old"
|
||||
maxLogFileSize = 5 * 1024 * 1024 // 5MB
|
||||
)
|
||||
|
||||
type rotatingWriter struct {
|
||||
mu sync.Mutex
|
||||
file *os.File
|
||||
currentSize int64
|
||||
maxSize int64
|
||||
logPath string
|
||||
oldLogPath string
|
||||
}
|
||||
|
||||
func (w *rotatingWriter) Write(p []byte) (int, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.currentSize+int64(len(p)) > w.maxSize {
|
||||
if err := w.rotate(); err != nil {
|
||||
return 0, fmt.Errorf("failed to rotate log file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
n, err := w.file.Write(p)
|
||||
w.currentSize += int64(n)
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *rotatingWriter) rotate() error {
|
||||
if err := w.file.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close %s: %w", w.logPath, err)
|
||||
}
|
||||
|
||||
if err := os.Remove(w.oldLogPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove %s: %w", w.oldLogPath, err)
|
||||
}
|
||||
|
||||
if err := os.Rename(w.logPath, w.oldLogPath); err != nil {
|
||||
return fmt.Errorf("failed to rename %s to %s: %w", w.logPath, w.oldLogPath, err)
|
||||
}
|
||||
|
||||
f, err := os.OpenFile(w.logPath, os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new %s: %w", w.logPath, err)
|
||||
}
|
||||
|
||||
w.file = f
|
||||
w.currentSize = 0
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var loggerInstance *slog.Logger
|
||||
|
||||
var initLogger = sync.OnceFunc(initialize)
|
||||
|
||||
func GetLogger() *slog.Logger {
|
||||
initLogger()
|
||||
return loggerInstance
|
||||
}
|
||||
|
||||
func initialize() {
|
||||
writer := buildWriter()
|
||||
|
||||
loggerInstance = slog.New(slog.NewTextHandler(writer, &slog.HandlerOptions{
|
||||
Level: slog.LevelInfo,
|
||||
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
|
||||
if a.Key == slog.TimeKey {
|
||||
a.Value = slog.StringValue(time.Now().Format("2006/01/02 15:04:05"))
|
||||
}
|
||||
if a.Key == slog.LevelKey {
|
||||
return slog.Attr{}
|
||||
}
|
||||
|
||||
return a
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
func buildWriter() io.Writer {
|
||||
f, err := os.OpenFile(logFileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: failed to open %s for logging: %v\n", logFileName, err)
|
||||
return os.Stdout
|
||||
}
|
||||
|
||||
var currentSize int64
|
||||
if info, err := f.Stat(); err == nil {
|
||||
currentSize = info.Size()
|
||||
}
|
||||
|
||||
rw := &rotatingWriter{
|
||||
file: f,
|
||||
currentSize: currentSize,
|
||||
maxSize: maxLogFileSize,
|
||||
logPath: logFileName,
|
||||
oldLogPath: oldLogFileName,
|
||||
}
|
||||
|
||||
return io.MultiWriter(os.Stdout, rw)
|
||||
}
|
||||
128
agent/internal/logger/logger_test.go
Normal file
128
agent/internal/logger/logger_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Write_DataWrittenToFile(t *testing.T) {
|
||||
rw, logPath, _ := setupRotatingWriter(t, 1024)
|
||||
|
||||
data := []byte("hello world\n")
|
||||
n, err := rw.Write(data)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(data), n)
|
||||
assert.Equal(t, int64(len(data)), rw.currentSize)
|
||||
|
||||
content, err := os.ReadFile(logPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(data), string(content))
|
||||
}
|
||||
|
||||
func Test_Write_WhenLimitExceeded_FileRotated(t *testing.T) {
|
||||
rw, logPath, oldLogPath := setupRotatingWriter(t, 100)
|
||||
|
||||
firstData := []byte(strings.Repeat("A", 80))
|
||||
_, err := rw.Write(firstData)
|
||||
require.NoError(t, err)
|
||||
|
||||
secondData := []byte(strings.Repeat("B", 30))
|
||||
_, err = rw.Write(secondData)
|
||||
require.NoError(t, err)
|
||||
|
||||
oldContent, err := os.ReadFile(oldLogPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(firstData), string(oldContent))
|
||||
|
||||
newContent, err := os.ReadFile(logPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(secondData), string(newContent))
|
||||
|
||||
assert.Equal(t, int64(len(secondData)), rw.currentSize)
|
||||
}
|
||||
|
||||
func Test_Write_WhenOldFileExists_OldFileReplaced(t *testing.T) {
|
||||
rw, _, oldLogPath := setupRotatingWriter(t, 100)
|
||||
|
||||
require.NoError(t, os.WriteFile(oldLogPath, []byte("stale data"), 0o644))
|
||||
|
||||
_, err := rw.Write([]byte(strings.Repeat("A", 80)))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = rw.Write([]byte(strings.Repeat("B", 30)))
|
||||
require.NoError(t, err)
|
||||
|
||||
oldContent, err := os.ReadFile(oldLogPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, strings.Repeat("A", 80), string(oldContent))
|
||||
}
|
||||
|
||||
func Test_Write_MultipleSmallWrites_CurrentSizeAccumulated(t *testing.T) {
|
||||
rw, _, _ := setupRotatingWriter(t, 1024)
|
||||
|
||||
var totalWritten int64
|
||||
for range 10 {
|
||||
data := []byte("line\n")
|
||||
n, err := rw.Write(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
totalWritten += int64(n)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalWritten, rw.currentSize)
|
||||
assert.Equal(t, int64(50), rw.currentSize)
|
||||
}
|
||||
|
||||
func Test_Write_ExactlyAtBoundary_NoRotationUntilNextByte(t *testing.T) {
|
||||
rw, logPath, oldLogPath := setupRotatingWriter(t, 100)
|
||||
|
||||
exactData := []byte(strings.Repeat("X", 100))
|
||||
_, err := rw.Write(exactData)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = os.Stat(oldLogPath)
|
||||
assert.True(t, os.IsNotExist(err), ".old file should not exist yet")
|
||||
|
||||
content, err := os.ReadFile(logPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(exactData), string(content))
|
||||
|
||||
_, err = rw.Write([]byte("Z"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = os.Stat(oldLogPath)
|
||||
assert.NoError(t, err, ".old file should exist after exceeding limit")
|
||||
|
||||
assert.Equal(t, int64(1), rw.currentSize)
|
||||
}
|
||||
|
||||
func setupRotatingWriter(t *testing.T, maxSize int64) (*rotatingWriter, string, string) {
|
||||
t.Helper()
|
||||
|
||||
dir := t.TempDir()
|
||||
logPath := filepath.Join(dir, "test.log")
|
||||
oldLogPath := filepath.Join(dir, "test.log.old")
|
||||
|
||||
f, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
rw := &rotatingWriter{
|
||||
file: f,
|
||||
currentSize: 0,
|
||||
maxSize: maxSize,
|
||||
logPath: logPath,
|
||||
oldLogPath: oldLogPath,
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
rw.file.Close()
|
||||
})
|
||||
|
||||
return rw, logPath, oldLogPath
|
||||
}
|
||||
@@ -1,152 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
Always place private methods to the bottom of file
|
||||
|
||||
**This rule applies to ALL Go files including tests, services, controllers, repositories, etc.**
|
||||
|
||||
In Go, exported (public) functions/methods start with uppercase letters, while unexported (private) ones start with lowercase letters.
|
||||
|
||||
## Structure Order:
|
||||
|
||||
1. Type definitions and constants
|
||||
2. Public methods/functions (uppercase)
|
||||
3. Private methods/functions (lowercase)
|
||||
|
||||
## Examples:
|
||||
|
||||
### Service with methods:
|
||||
|
||||
```go
|
||||
type UserService struct {
|
||||
repository *UserRepository
|
||||
}
|
||||
|
||||
// Public methods first
|
||||
func (s *UserService) CreateUser(user *User) error {
|
||||
if err := s.validateUser(user); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.repository.Save(user)
|
||||
}
|
||||
|
||||
func (s *UserService) GetUser(id uuid.UUID) (*User, error) {
|
||||
return s.repository.FindByID(id)
|
||||
}
|
||||
|
||||
// Private methods at the bottom
|
||||
func (s *UserService) validateUser(user *User) error {
|
||||
if user.Name == "" {
|
||||
return errors.New("name is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
### Package-level functions:
|
||||
|
||||
```go
|
||||
package utils
|
||||
|
||||
// Public functions first
|
||||
func ProcessData(data []byte) (Result, error) {
|
||||
cleaned := sanitizeInput(data)
|
||||
return parseData(cleaned)
|
||||
}
|
||||
|
||||
func ValidateInput(input string) bool {
|
||||
return isValidFormat(input) && checkLength(input)
|
||||
}
|
||||
|
||||
// Private functions at the bottom
|
||||
func sanitizeInput(data []byte) []byte {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func parseData(data []byte) (Result, error) {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func isValidFormat(input string) bool {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func checkLength(input string) bool {
|
||||
// implementation
|
||||
}
|
||||
```
|
||||
|
||||
### Test files:
|
||||
|
||||
```go
|
||||
package user_test
|
||||
|
||||
// Public test functions first
|
||||
func Test_CreateUser_ValidInput_UserCreated(t *testing.T) {
|
||||
user := createTestUser()
|
||||
result, err := service.CreateUser(user)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func Test_GetUser_ExistingUser_ReturnsUser(t *testing.T) {
|
||||
user := createTestUser()
|
||||
// test implementation
|
||||
}
|
||||
|
||||
// Private helper functions at the bottom
|
||||
func createTestUser() *User {
|
||||
return &User{
|
||||
Name: "Test User",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestDatabase() *Database {
|
||||
// setup implementation
|
||||
}
|
||||
```
|
||||
|
||||
### Controller example:
|
||||
|
||||
```go
|
||||
type ProjectController struct {
|
||||
service *ProjectService
|
||||
}
|
||||
|
||||
// Public HTTP handlers first
|
||||
func (c *ProjectController) CreateProject(ctx *gin.Context) {
|
||||
var request CreateProjectRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
c.handleError(ctx, err)
|
||||
return
|
||||
}
|
||||
// handler logic
|
||||
}
|
||||
|
||||
func (c *ProjectController) GetProject(ctx *gin.Context) {
|
||||
projectID := c.extractProjectID(ctx)
|
||||
// handler logic
|
||||
}
|
||||
|
||||
// Private helper methods at the bottom
|
||||
func (c *ProjectController) handleError(ctx *gin.Context, err error) {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
}
|
||||
|
||||
func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
|
||||
return uuid.MustParse(ctx.Param("projectId"))
|
||||
}
|
||||
```
|
||||
|
||||
## Key Points:
|
||||
|
||||
- **Exported/Public** = starts with uppercase letter (CreateUser, GetProject)
|
||||
- **Unexported/Private** = starts with lowercase letter (validateUser, handleError)
|
||||
- This improves code readability by showing the public API first
|
||||
- Private helpers are implementation details, so they go at the bottom
|
||||
- Apply this rule consistently across ALL Go files in the project
|
||||
@@ -1,45 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
## Comment Guidelines
|
||||
|
||||
1. **No obvious comments** - Don't state what the code already clearly shows
|
||||
2. **Functions and variables should have meaningful names** - Code should be self-documenting
|
||||
3. **Comments for unclear code only** - Only add comments when code logic isn't immediately clear
|
||||
|
||||
## Key Principles:
|
||||
|
||||
- **Code should tell a story** - Use descriptive variable and function names
|
||||
- **Comments explain WHY, not WHAT** - The code shows what happens, comments explain business logic or complex decisions
|
||||
- **Prefer refactoring over commenting** - If code needs explaining, consider making it clearer instead
|
||||
- **API documentation is required** - Swagger comments for all HTTP endpoints are mandatory
|
||||
- **Complex algorithms deserve comments** - Mathematical formulas, business rules, or non-obvious optimizations
|
||||
|
||||
Example of useless comment:
|
||||
|
||||
1.
|
||||
|
||||
```sql
|
||||
// Create projects table
|
||||
CREATE TABLE projects (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
name TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
```
|
||||
|
||||
2.
|
||||
|
||||
```go
|
||||
// Create test project
|
||||
project := CreateTestProject(projectName, user, router)
|
||||
```
|
||||
|
||||
3.
|
||||
|
||||
```go
|
||||
// CreateValidLogItems creates valid log items for testing
|
||||
func CreateValidLogItems(count int, uniqueID string) []logs_receiving.LogItemRequestDTO {
|
||||
```
|
||||
@@ -1,133 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
1. When we write controller:
|
||||
|
||||
- we combine all routes to single controller
|
||||
- names them as .WhatWeDo (not "handlers") concept
|
||||
|
||||
2. We use gin and \*gin.Context for all routes.
|
||||
Example:
|
||||
|
||||
func (c *TasksController) GetAvailableTasks(ctx *gin.Context) ...
|
||||
|
||||
3. We document all routes with Swagger in the following format:
|
||||
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
auditLogService \*AuditLogService
|
||||
}
|
||||
|
||||
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// All audit log endpoints require authentication (handled in main.go)
|
||||
auditRoutes := router.Group("/audit-logs")
|
||||
|
||||
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
|
||||
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
|
||||
|
||||
}
|
||||
|
||||
// GetGlobalAuditLogs
|
||||
// @Summary Get global audit logs (ADMIN only)
|
||||
// @Description Retrieve all audit logs across the system
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/global [get]
|
||||
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(\*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "only administrators can view global audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
|
||||
}
|
||||
|
||||
// GetUserAuditLogs
|
||||
// @Summary Get user audit logs
|
||||
// @Description Retrieve audit logs for a specific user
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param userId path string true "User ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/users/{userId} [get]
|
||||
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(\*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
userIDStr := ctx.Param("userId")
|
||||
targetUserID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
|
||||
}
|
||||
@@ -1,671 +0,0 @@
|
||||
---
|
||||
alwaysApply: false
|
||||
---
|
||||
|
||||
This is example of CRUD:
|
||||
|
||||
------ backend/internal/features/audit_logs/controller.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
auditLogService *AuditLogService
|
||||
}
|
||||
|
||||
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// All audit log endpoints require authentication (handled in main.go)
|
||||
auditRoutes := router.Group("/audit-logs")
|
||||
|
||||
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
|
||||
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
|
||||
}
|
||||
|
||||
// GetGlobalAuditLogs
|
||||
// @Summary Get global audit logs (ADMIN only)
|
||||
// @Description Retrieve all audit logs across the system
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/global [get]
|
||||
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "only administrators can view global audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetUserAuditLogs
|
||||
// @Summary Get user audit logs
|
||||
// @Description Retrieve audit logs for a specific user
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param userId path string true "User ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/users/{userId} [get]
|
||||
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
userIDStr := ctx.Param("userId")
|
||||
targetUserID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/controller_test.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_GetGlobalAuditLogs_AdminSucceedsAndMemberGetsForbidden(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
memberUser := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
projectID := uuid.New()
|
||||
|
||||
// Create test logs
|
||||
createAuditLog(service, "Test log with user", &adminUser.UserID, nil)
|
||||
createAuditLog(service, "Test log with project", nil, &projectID)
|
||||
createAuditLog(service, "Test log standalone", nil, nil)
|
||||
|
||||
// Test ADMIN can access global logs
|
||||
var response GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
"/api/v1/audit-logs/global?limit=10", "Bearer "+adminUser.Token, http.StatusOK, &response)
|
||||
|
||||
assert.GreaterOrEqual(t, len(response.AuditLogs), 3)
|
||||
assert.GreaterOrEqual(t, response.Total, int64(3))
|
||||
|
||||
messages := extractMessages(response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test log with user")
|
||||
assert.Contains(t, messages, "Test log with project")
|
||||
assert.Contains(t, messages, "Test log standalone")
|
||||
|
||||
// Test MEMBER cannot access global logs
|
||||
resp := test_utils.MakeGetRequest(t, router, "/api/v1/audit-logs/global",
|
||||
"Bearer "+memberUser.Token, http.StatusForbidden)
|
||||
assert.Contains(t, string(resp.Body), "only administrators can view global audit logs")
|
||||
}
|
||||
|
||||
func Test_GetUserAuditLogs_PermissionsEnforcedCorrectly(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
projectID := uuid.New()
|
||||
|
||||
// Create test logs for different users
|
||||
createAuditLog(service, "Test log user1 first", &user1.UserID, nil)
|
||||
createAuditLog(service, "Test log user1 second", &user1.UserID, &projectID)
|
||||
createAuditLog(service, "Test log user2 first", &user2.UserID, nil)
|
||||
createAuditLog(service, "Test log user2 second", &user2.UserID, &projectID)
|
||||
createAuditLog(service, "Test project log", nil, &projectID)
|
||||
|
||||
// Test ADMIN can view any user's logs
|
||||
var user1Response GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=10", user1.UserID.String()),
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &user1Response)
|
||||
|
||||
assert.Equal(t, 2, len(user1Response.AuditLogs))
|
||||
messages := extractMessages(user1Response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test log user1 first")
|
||||
assert.Contains(t, messages, "Test log user1 second")
|
||||
|
||||
// Test user can view own logs
|
||||
var ownLogsResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s", user2.UserID.String()),
|
||||
"Bearer "+user2.Token, http.StatusOK, &ownLogsResponse)
|
||||
assert.Equal(t, 2, len(ownLogsResponse.AuditLogs))
|
||||
|
||||
// Test user cannot view other user's logs
|
||||
resp := test_utils.MakeGetRequest(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s", user1.UserID.String()),
|
||||
"Bearer "+user2.Token, http.StatusForbidden)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
func Test_FilterAuditLogsByTime_ReturnsOnlyLogsBeforeDate(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
db := storage.GetDb()
|
||||
baseTime := time.Now().UTC()
|
||||
|
||||
// Create logs with different timestamps
|
||||
createTimedLog(db, &adminUser.UserID, "Test old log", baseTime.Add(-2*time.Hour))
|
||||
createTimedLog(db, &adminUser.UserID, "Test recent log", baseTime.Add(-30*time.Minute))
|
||||
createAuditLog(service, "Test current log", &adminUser.UserID, nil)
|
||||
|
||||
// Test filtering - get logs before 1 hour ago
|
||||
beforeTime := baseTime.Add(-1 * time.Hour)
|
||||
var filteredResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/global?beforeDate=%s", beforeTime.Format(time.RFC3339)),
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &filteredResponse)
|
||||
|
||||
// Verify only old log is returned
|
||||
messages := extractMessages(filteredResponse.AuditLogs)
|
||||
assert.Contains(t, messages, "Test old log")
|
||||
assert.NotContains(t, messages, "Test recent log")
|
||||
assert.NotContains(t, messages, "Test current log")
|
||||
|
||||
// Test without filter - should get all logs
|
||||
var allResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router, "/api/v1/audit-logs/global",
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &allResponse)
|
||||
assert.GreaterOrEqual(t, len(allResponse.AuditLogs), 3)
|
||||
}
|
||||
|
||||
func createRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
SetupDependencies()
|
||||
|
||||
v1 := router.Group("/api/v1")
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
GetAuditLogController().RegisterRoutes(protected.(*gin.RouterGroup))
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/di.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var auditLogRepository = &AuditLogRepository{}
|
||||
var auditLogService = &AuditLogService{
|
||||
auditLogRepository: auditLogRepository,
|
||||
logger: logger.GetLogger(),
|
||||
}
|
||||
var auditLogController = &AuditLogController{
|
||||
auditLogService: auditLogService,
|
||||
}
|
||||
|
||||
func GetAuditLogService() *AuditLogService {
|
||||
return auditLogService
|
||||
}
|
||||
|
||||
func GetAuditLogController() *AuditLogController {
|
||||
return auditLogController
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/dto.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import "time"
|
||||
|
||||
type GetAuditLogsRequest struct {
|
||||
Limit int `form:"limit" json:"limit"`
|
||||
Offset int `form:"offset" json:"offset"`
|
||||
BeforeDate *time.Time `form:"beforeDate" json:"beforeDate"`
|
||||
}
|
||||
|
||||
type GetAuditLogsResponse struct {
|
||||
AuditLogs []*AuditLog `json:"auditLogs"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/models.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLog struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id"`
|
||||
UserID *uuid.UUID `json:"userId" gorm:"column:user_id"`
|
||||
ProjectID *uuid.UUID `json:"projectId" gorm:"column:project_id"`
|
||||
Message string `json:"message" gorm:"column:message"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
||||
}
|
||||
|
||||
func (AuditLog) TableName() string {
|
||||
return "audit_logs"
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/repository.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogRepository struct{}
|
||||
|
||||
func (r *AuditLogRepository) Create(auditLog *AuditLog) error {
|
||||
if auditLog.ID == uuid.Nil {
|
||||
auditLog.ID = uuid.New()
|
||||
}
|
||||
|
||||
return storage.GetDb().Create(auditLog).Error
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetGlobal(limit, offset int, beforeDate *time.Time) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetByUser(
|
||||
userID uuid.UUID,
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().
|
||||
Where("user_id = ?", userID).
|
||||
Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetByProject(
|
||||
projectID uuid.UUID,
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().
|
||||
Where("project_id = ?", projectID).
|
||||
Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
|
||||
var count int64
|
||||
query := storage.GetDb().Model(&AuditLog{})
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/service.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogService struct {
|
||||
auditLogRepository *AuditLogRepository
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *AuditLogService) WriteAuditLog(
|
||||
message string,
|
||||
userID *uuid.UUID,
|
||||
projectID *uuid.UUID,
|
||||
) {
|
||||
auditLog := &AuditLog{
|
||||
UserID: userID,
|
||||
ProjectID: projectID,
|
||||
Message: message,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
err := s.auditLogRepository.Create(auditLog)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to create audit log", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuditLogService) CreateAuditLog(auditLog *AuditLog) error {
|
||||
return s.auditLogRepository.Create(auditLog)
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetGlobalAuditLogs(
|
||||
user *user_models.User,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
if user.Role != user_enums.UserRoleAdmin {
|
||||
return nil, errors.New("only administrators can view global audit logs")
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetGlobal(limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
total, err := s.auditLogRepository.CountGlobal(request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: total,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetUserAuditLogs(
|
||||
targetUserID uuid.UUID,
|
||||
user *user_models.User,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
// Users can view their own logs, ADMIN can view any user's logs
|
||||
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
|
||||
return nil, errors.New("insufficient permissions to view user audit logs")
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetByUser(targetUserID, limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: int64(len(auditLogs)),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetProjectAuditLogs(
|
||||
projectID uuid.UUID,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetByProject(projectID, limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: int64(len(auditLogs)),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/service_test.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func Test_AuditLogs_ProjectSpecificLogs(t *testing.T) {
|
||||
service := GetAuditLogService()
|
||||
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
project1ID, project2ID := uuid.New(), uuid.New()
|
||||
|
||||
// Create test logs for projects
|
||||
createAuditLog(service, "Test project1 log first", &user1.UserID, &project1ID)
|
||||
createAuditLog(service, "Test project1 log second", &user2.UserID, &project1ID)
|
||||
createAuditLog(service, "Test project2 log first", &user1.UserID, &project2ID)
|
||||
createAuditLog(service, "Test project2 log second", &user2.UserID, &project2ID)
|
||||
createAuditLog(service, "Test no project log", &user1.UserID, nil)
|
||||
|
||||
request := &GetAuditLogsRequest{Limit: 10, Offset: 0}
|
||||
|
||||
// Test project 1 logs
|
||||
project1Response, err := service.GetProjectAuditLogs(project1ID, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(project1Response.AuditLogs))
|
||||
|
||||
messages := extractMessages(project1Response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test project1 log first")
|
||||
assert.Contains(t, messages, "Test project1 log second")
|
||||
for _, log := range project1Response.AuditLogs {
|
||||
assert.Equal(t, &project1ID, log.ProjectID)
|
||||
}
|
||||
|
||||
// Test project 2 logs
|
||||
project2Response, err := service.GetProjectAuditLogs(project2ID, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(project2Response.AuditLogs))
|
||||
|
||||
messages2 := extractMessages(project2Response.AuditLogs)
|
||||
assert.Contains(t, messages2, "Test project2 log first")
|
||||
assert.Contains(t, messages2, "Test project2 log second")
|
||||
|
||||
// Test pagination
|
||||
limitedResponse, err := service.GetProjectAuditLogs(project1ID,
|
||||
&GetAuditLogsRequest{Limit: 1, Offset: 0})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(limitedResponse.AuditLogs))
|
||||
assert.Equal(t, 1, limitedResponse.Limit)
|
||||
|
||||
// Test beforeDate filter
|
||||
beforeTime := time.Now().UTC().Add(-1 * time.Minute)
|
||||
filteredResponse, err := service.GetProjectAuditLogs(project1ID,
|
||||
&GetAuditLogsRequest{Limit: 10, BeforeDate: &beforeTime})
|
||||
assert.NoError(t, err)
|
||||
for _, log := range filteredResponse.AuditLogs {
|
||||
assert.True(t, log.CreatedAt.Before(beforeTime))
|
||||
}
|
||||
}
|
||||
|
||||
func createAuditLog(service *AuditLogService, message string, userID, projectID *uuid.UUID) {
|
||||
service.WriteAuditLog(message, userID, projectID)
|
||||
}
|
||||
|
||||
func extractMessages(logs []*AuditLog) []string {
|
||||
messages := make([]string, len(logs))
|
||||
for i, log := range logs {
|
||||
messages[i] = log.Message
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt time.Time) {
|
||||
log := &AuditLog{
|
||||
ID: uuid.New(),
|
||||
UserID: userID,
|
||||
Message: message,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
db.Create(log)
|
||||
}
|
||||
|
||||
```
|
||||
@@ -1,74 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
For DI files use implicit fields declaration styles (espesially
|
||||
for controllers, services, repositories, use cases, etc., not simple
|
||||
data structures).
|
||||
|
||||
So, instead of:
|
||||
|
||||
var orderController = &OrderController{
|
||||
orderService: orderService,
|
||||
botUserService: bot_users.GetBotUserService(),
|
||||
botService: bots.GetBotService(),
|
||||
userService: users.GetUserService(),
|
||||
}
|
||||
|
||||
Use:
|
||||
|
||||
var orderController = &OrderController{
|
||||
orderService,
|
||||
bot_users.GetBotUserService(),
|
||||
bots.GetBotService(),
|
||||
users.GetUserService(),
|
||||
}
|
||||
|
||||
This is needed to avoid forgetting to update DI style
|
||||
when we add new dependency.
|
||||
|
||||
---
|
||||
|
||||
Please force such usage if file look like this (see some
|
||||
services\controllers\repos definitions and getters):
|
||||
|
||||
var orderBackgroundService = &OrderBackgroundService{
|
||||
orderService: orderService,
|
||||
orderPaymentRepository: orderPaymentRepository,
|
||||
botService: bots.GetBotService(),
|
||||
paymentSettingsService: payment_settings.GetPaymentSettingsService(),
|
||||
|
||||
orderSubscriptionListeners: []OrderSubscriptionListener{},
|
||||
}
|
||||
|
||||
var orderController = &OrderController{
|
||||
orderService: orderService,
|
||||
botUserService: bot_users.GetBotUserService(),
|
||||
botService: bots.GetBotService(),
|
||||
userService: users.GetUserService(),
|
||||
}
|
||||
|
||||
func GetUniquePaymentRepository() *repositories.UniquePaymentRepository {
|
||||
return uniquePaymentRepository
|
||||
}
|
||||
|
||||
func GetOrderPaymentRepository() *repositories.OrderPaymentRepository {
|
||||
return orderPaymentRepository
|
||||
}
|
||||
|
||||
func GetOrderService() *OrderService {
|
||||
return orderService
|
||||
}
|
||||
|
||||
func GetOrderController() *OrderController {
|
||||
return orderController
|
||||
}
|
||||
|
||||
func GetOrderBackgroundService() *OrderBackgroundService {
|
||||
return orderBackgroundService
|
||||
}
|
||||
|
||||
func GetOrderRepository() *repositories.OrderRepository {
|
||||
return orderRepository
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
When writting migrations:
|
||||
|
||||
- write them for PostgreSQL
|
||||
- for PRIMARY UUID keys use gen_random_uuid()
|
||||
- for time use TIMESTAMPTZ (timestamp with zone)
|
||||
- split table, constraint and indexes declaration (table first, them other one by one)
|
||||
- format SQL in pretty way (add spaces, align columns types), constraints split by lines. The example:
|
||||
|
||||
CREATE TABLE marketplace_info (
|
||||
bot_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
title TEXT NOT NULL,
|
||||
description TEXT NOT NULL,
|
||||
short_description TEXT NOT NULL,
|
||||
tutorial_url TEXT,
|
||||
info_order BIGINT NOT NULL DEFAULT 0,
|
||||
is_published BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
|
||||
ALTER TABLE marketplace_info_images
|
||||
ADD CONSTRAINT fk_marketplace_info_images_bot_id
|
||||
FOREIGN KEY (bot_id)
|
||||
REFERENCES marketplace_info (bot_id);
|
||||
@@ -1,12 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
When applying changes, do not forget to refactor old code.
|
||||
|
||||
You can shortify, make more readable, improve code quality, etc.
|
||||
Common logic can be extracted to functions, constants, files, etc.
|
||||
|
||||
After each large change with more than ~50-100 lines of code - always run `make lint` (from backend root folder) and, if you change frontend, run `npm run format` (from frontend root folder).
|
||||
@@ -1,147 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
After writing tests, always launch them and verify that they pass.
|
||||
|
||||
## Test Naming Format
|
||||
|
||||
Use these naming patterns:
|
||||
|
||||
- `Test_WhatWeDo_WhatWeExpect`
|
||||
- `Test_WhatWeDo_WhichConditions_WhatWeExpect`
|
||||
|
||||
## Examples from Real Codebase:
|
||||
|
||||
- `Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated`
|
||||
- `Test_UpdateProject_WhenUserIsProjectAdmin_ProjectUpdated`
|
||||
- `Test_DeleteApiKey_WhenUserIsProjectMember_ReturnsForbidden`
|
||||
- `Test_GetProjectAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly`
|
||||
- `Test_ProjectLifecycleE2E_CompletesSuccessfully`
|
||||
|
||||
## Testing Philosophy
|
||||
|
||||
**Prefer Controllers Over Unit Tests:**
|
||||
|
||||
- Test through HTTP endpoints via controllers whenever possible
|
||||
- Avoid testing repositories, services in isolation - test via API instead
|
||||
- Only use unit tests for complex model logic when no API exists
|
||||
- Name test files `controller_test.go` or `service_test.go`, not `integration_test.go`
|
||||
|
||||
**Extract Common Logic to Testing Utilities:**
|
||||
|
||||
- Create `testing.go` or `testing/testing.go` files for shared test utilities
|
||||
- Extract router creation, user setup, models creation helpers (in API, not just structs creation)
|
||||
- Reuse common patterns across different test files
|
||||
|
||||
**Refactor Existing Tests:**
|
||||
|
||||
- When working with existing tests, always look for opportunities to refactor and improve
|
||||
- Extract repetitive setup code to common utilities
|
||||
- Simplify complex tests by breaking them into smaller, focused tests
|
||||
- Replace inline test data creation with reusable helper functions
|
||||
- Consolidate similar test patterns across different test files
|
||||
- Make tests more readable and maintainable for other developers
|
||||
|
||||
## Testing Utilities Structure
|
||||
|
||||
**Create `testing.go` or `testing/testing.go` files with common utilities:**
|
||||
|
||||
```go
|
||||
package projects_testing
|
||||
|
||||
// CreateTestRouter creates unified router for all controllers
|
||||
func CreateTestRouter(controllers ...ControllerInterface) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
v1 := router.Group("/api/v1")
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
|
||||
for _, controller := range controllers {
|
||||
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
|
||||
controller.RegisterRoutes(routerGroup)
|
||||
}
|
||||
}
|
||||
return router
|
||||
}
|
||||
|
||||
// CreateTestProjectViaAPI creates project through HTTP API
|
||||
func CreateTestProjectViaAPI(name string, owner *users_dto.SignInResponseDTO, router *gin.Engine) (*projects_models.Project, string) {
|
||||
request := projects_dto.CreateProjectRequestDTO{Name: name}
|
||||
w := MakeAPIRequest(router, "POST", "/api/v1/projects", "Bearer "+owner.Token, request)
|
||||
// Handle response...
|
||||
return project, owner.Token
|
||||
}
|
||||
|
||||
// AddMemberToProject adds member via API call
|
||||
func AddMemberToProject(project *projects_models.Project, member *users_dto.SignInResponseDTO, role users_enums.ProjectRole, ownerToken string, router *gin.Engine) {
|
||||
// Implementation...
|
||||
}
|
||||
```
|
||||
|
||||
## Controller Test Examples
|
||||
|
||||
**Permission-based testing:**
|
||||
|
||||
```go
|
||||
func Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated(t *testing.T) {
|
||||
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
project, _ := projects_testing.CreateTestProjectViaAPI("Test Project", owner, router)
|
||||
|
||||
request := CreateApiKeyRequestDTO{Name: "Test API Key"}
|
||||
var response ApiKey
|
||||
test_utils.MakePostRequestAndUnmarshal(t, router, "/api/v1/projects/api-keys/"+project.ID.String(), "Bearer "+owner.Token, request, http.StatusOK, &response)
|
||||
|
||||
assert.Equal(t, "Test API Key", response.Name)
|
||||
assert.NotEmpty(t, response.Token)
|
||||
}
|
||||
```
|
||||
|
||||
**Cross-project security testing:**
|
||||
|
||||
```go
|
||||
func Test_UpdateApiKey_WithApiKeyFromDifferentProject_ReturnsBadRequest(t *testing.T) {
|
||||
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
project1, _ := projects_testing.CreateTestProjectViaAPI("Project 1", owner1, router)
|
||||
project2, _ := projects_testing.CreateTestProjectViaAPI("Project 2", owner2, router)
|
||||
|
||||
apiKey := CreateTestApiKey("Cross Project Key", project1.ID, owner1.Token, router)
|
||||
|
||||
// Try to update via different project endpoint
|
||||
request := UpdateApiKeyRequestDTO{Name: &"Hacked Key"}
|
||||
resp := test_utils.MakePutRequest(t, router, "/api/v1/projects/api-keys/"+project2.ID.String()+"/"+apiKey.ID.String(), "Bearer "+owner2.Token, request, http.StatusBadRequest)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "API key does not belong to this project")
|
||||
}
|
||||
```
|
||||
|
||||
**E2E lifecycle testing:**
|
||||
|
||||
```go
|
||||
func Test_ProjectLifecycleE2E_CompletesSuccessfully(t *testing.T) {
|
||||
router := projects_testing.CreateTestRouter(GetProjectController(), GetMembershipController())
|
||||
|
||||
// 1. Create project
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
project := projects_testing.CreateTestProject("E2E Project", owner, router)
|
||||
|
||||
// 2. Add member
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
projects_testing.AddMemberToProject(project, member, users_enums.ProjectRoleMember, owner.Token, router)
|
||||
|
||||
// 3. Promote to admin
|
||||
projects_testing.ChangeMemberRole(project, member.UserID, users_enums.ProjectRoleAdmin, owner.Token, router)
|
||||
|
||||
// 4. Transfer ownership
|
||||
projects_testing.TransferProjectOwnership(project, member.UserID, owner.Token, router)
|
||||
|
||||
// 5. Verify new owner can manage project
|
||||
finalProject := projects_testing.GetProject(project.ID, member.Token, router)
|
||||
assert.Equal(t, project.ID, finalProject.ID)
|
||||
}
|
||||
```
|
||||
@@ -1,6 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
Always use time.Now().UTC() instead of time.Now()
|
||||
@@ -2,8 +2,18 @@
|
||||
DEV_DB_NAME=databasus
|
||||
DEV_DB_USERNAME=postgres
|
||||
DEV_DB_PASSWORD=Q1234567
|
||||
#app
|
||||
# app
|
||||
ENV_MODE=development
|
||||
# logging
|
||||
SHOW_DB_INSTALLATION_VERIFICATION_LOGS=true
|
||||
VICTORIA_LOGS_URL=http://localhost:9428
|
||||
VICTORIA_LOGS_PASSWORD=devpassword
|
||||
# tests
|
||||
TEST_LOCALHOST=localhost
|
||||
IS_SKIP_EXTERNAL_RESOURCES_TESTS=false
|
||||
# cloudflare turnstile
|
||||
CLOUDFLARE_TURNSTILE_SITE_KEY=
|
||||
CLOUDFLARE_TURNSTILE_SECRET_KEY=
|
||||
# db
|
||||
DATABASE_DSN=host=dev-db user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
|
||||
DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
|
||||
@@ -11,6 +21,19 @@ DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
|
||||
GOOSE_DRIVER=postgres
|
||||
GOOSE_DBSTRING=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
|
||||
GOOSE_MIGRATION_DIR=./migrations
|
||||
# valkey
|
||||
VALKEY_HOST=127.0.0.1
|
||||
VALKEY_PORT=6379
|
||||
VALKEY_USERNAME=
|
||||
VALKEY_PASSWORD=
|
||||
VALKEY_IS_SSL=false
|
||||
# billing
|
||||
PRICE_PER_GB_CENTS=
|
||||
IS_PADDLE_SANDBOX=true
|
||||
PADDLE_API_KEY=
|
||||
PADDLE_WEBHOOK_SECRET=
|
||||
PADDLE_PRICE_ID=
|
||||
PADDLE_CLIENT_TOKEN=
|
||||
# testing
|
||||
# to get Google Drive env variables: add storage in UI and copy data from added storage here
|
||||
TEST_GOOGLE_DRIVE_CLIENT_ID=
|
||||
@@ -29,9 +52,6 @@ TEST_MINIO_PORT=9000
|
||||
TEST_MINIO_CONSOLE_PORT=9001
|
||||
# testing NAS
|
||||
TEST_NAS_PORT=7006
|
||||
# testing Telegram
|
||||
TEST_TELEGRAM_BOT_TOKEN=
|
||||
TEST_TELEGRAM_CHAT_ID=
|
||||
# testing Azure Blob Storage
|
||||
TEST_AZURITE_BLOB_PORT=10000
|
||||
# supabase
|
||||
|
||||
@@ -10,4 +10,10 @@ DATABASE_URL=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disab
|
||||
# migrations
|
||||
GOOSE_DRIVER=postgres
|
||||
GOOSE_DBSTRING=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
|
||||
GOOSE_MIGRATION_DIR=./migrations
|
||||
GOOSE_MIGRATION_DIR=./migrations
|
||||
# valkey
|
||||
VALKEY_HOST=127.0.0.1
|
||||
VALKEY_PORT=6379
|
||||
VALKEY_USERNAME=
|
||||
VALKEY_PASSWORD=
|
||||
VALKEY_IS_SSL=false
|
||||
4
backend/.gitignore
vendored
4
backend/.gitignore
vendored
@@ -17,4 +17,6 @@ ui/build/*
|
||||
pgdata-for-restore/
|
||||
temp/
|
||||
cmd.exe
|
||||
temp/
|
||||
temp/
|
||||
valkey-data/
|
||||
victoria-logs-data/
|
||||
@@ -7,6 +7,16 @@ run:
|
||||
|
||||
linters:
|
||||
default: standard
|
||||
enable:
|
||||
- funcorder
|
||||
- bodyclose
|
||||
- errorlint
|
||||
- gocritic
|
||||
- unconvert
|
||||
- misspell
|
||||
- errname
|
||||
- noctx
|
||||
- modernize
|
||||
|
||||
settings:
|
||||
errcheck:
|
||||
@@ -14,6 +24,18 @@ linters:
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
- gofmt
|
||||
- gofumpt
|
||||
- golines
|
||||
- goimports
|
||||
- gci
|
||||
|
||||
settings:
|
||||
golines:
|
||||
max-len: 120
|
||||
gofumpt:
|
||||
module-path: databasus-backend
|
||||
extra-rules: true
|
||||
gci:
|
||||
sections:
|
||||
- standard
|
||||
- default
|
||||
- localmodule
|
||||
|
||||
@@ -2,10 +2,10 @@ run:
|
||||
go run cmd/main.go
|
||||
|
||||
test:
|
||||
go test -p=1 -count=1 -failfast -timeout 10m ./internal/...
|
||||
go test -p=1 -count=1 -failfast -timeout 15m ./internal/...
|
||||
|
||||
lint:
|
||||
golangci-lint fmt && golangci-lint run
|
||||
golangci-lint fmt ./cmd/... ./internal/... && golangci-lint run ./cmd/... ./internal/...
|
||||
|
||||
migration-create:
|
||||
goose create $(name) sql
|
||||
|
||||
@@ -9,13 +9,25 @@ import (
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime/debug"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-contrib/gzip"
|
||||
"github.com/gin-gonic/gin"
|
||||
swaggerFiles "github.com/swaggo/files"
|
||||
ginSwagger "github.com/swaggo/gin-swagger"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/billing"
|
||||
billing_paddle "databasus-backend/internal/features/billing/paddle"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/disk"
|
||||
"databasus-backend/internal/features/encryption/secrets"
|
||||
@@ -23,22 +35,21 @@ import (
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/restores"
|
||||
"databasus-backend/internal/features/restores/restoring"
|
||||
"databasus-backend/internal/features/storages"
|
||||
system_agent "databasus-backend/internal/features/system/agent"
|
||||
system_healthcheck "databasus-backend/internal/features/system/healthcheck"
|
||||
system_version "databasus-backend/internal/features/system/version"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
users_controllers "databasus-backend/internal/features/users/controllers"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
env_utils "databasus-backend/internal/util/env"
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
"databasus-backend/internal/util/logger"
|
||||
_ "databasus-backend/swagger" // swagger docs
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-contrib/gzip"
|
||||
"github.com/gin-gonic/gin"
|
||||
swaggerFiles "github.com/swaggo/files"
|
||||
ginSwagger "github.com/swaggo/gin-swagger"
|
||||
)
|
||||
|
||||
// @title Databasus Backend API
|
||||
@@ -52,14 +63,29 @@ import (
|
||||
func main() {
|
||||
log := logger.GetLogger()
|
||||
|
||||
runMigrations(log)
|
||||
cache_utils.TestCacheConnection()
|
||||
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
log.Info("Clearing cache...")
|
||||
|
||||
err := cache_utils.ClearAllCache()
|
||||
if err != nil {
|
||||
log.Error("Failed to clear cache", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
runMigrations(log)
|
||||
} else {
|
||||
log.Info("Skipping migrations (IS_PRIMARY_NODE is false)")
|
||||
}
|
||||
|
||||
// create directories that used for backups and restore
|
||||
err := files_utils.EnsureDirectories([]string{
|
||||
config.GetEnv().TempFolder,
|
||||
config.GetEnv().DataFolder,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error("Failed to ensure directories", "error", err)
|
||||
os.Exit(1)
|
||||
@@ -82,7 +108,9 @@ func main() {
|
||||
go generateSwaggerDocs(log)
|
||||
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
ginApp := gin.Default()
|
||||
ginApp := gin.New()
|
||||
ginApp.Use(gin.Logger())
|
||||
ginApp.Use(ginRecoveryWithLogger(log))
|
||||
|
||||
// Add GZIP compression middleware
|
||||
ginApp.Use(gzip.Gzip(
|
||||
@@ -96,7 +124,9 @@ func main() {
|
||||
enableCors(ginApp)
|
||||
setUpRoutes(ginApp)
|
||||
setUpDependencies()
|
||||
|
||||
runBackgroundTasks(log)
|
||||
|
||||
mountFrontend(ginApp)
|
||||
|
||||
startServerWithGracefulShutdown(log, ginApp)
|
||||
@@ -124,7 +154,7 @@ func handlePasswordReset(log *slog.Logger) {
|
||||
resetPassword(*email, *newPassword, log)
|
||||
}
|
||||
|
||||
func resetPassword(email string, newPassword string, log *slog.Logger) {
|
||||
func resetPassword(email, newPassword string, log *slog.Logger) {
|
||||
log.Info("Resetting password...")
|
||||
|
||||
userService := users_services.GetUserService()
|
||||
@@ -162,6 +192,9 @@ func startServerWithGracefulShutdown(log *slog.Logger, app *gin.Engine) {
|
||||
<-quit
|
||||
log.Info("Shutdown signal received")
|
||||
|
||||
// Gracefully shutdown VictoriaLogs writer
|
||||
logger.ShutdownVictoriaLogs()
|
||||
|
||||
// The context is used to inform the server it has 10 seconds to finish
|
||||
// the request it is currently handling
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -183,7 +216,15 @@ func setUpRoutes(r *gin.Engine) {
|
||||
userController := users_controllers.GetUserController()
|
||||
userController.RegisterRoutes(v1)
|
||||
system_healthcheck.GetHealthcheckController().RegisterRoutes(v1)
|
||||
backups.GetBackupController().RegisterPublicRoutes(v1)
|
||||
system_version.GetVersionController().RegisterRoutes(v1)
|
||||
system_agent.GetAgentController().RegisterRoutes(v1)
|
||||
backups_controllers.GetBackupController().RegisterPublicRoutes(v1)
|
||||
backups_controllers.GetPostgresWalBackupController().RegisterRoutes(v1)
|
||||
databases.GetDatabaseController().RegisterPublicRoutes(v1)
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
billing_paddle.GetPaddleBillingController().RegisterPublicRoutes(v1)
|
||||
}
|
||||
|
||||
// Setup auth middleware
|
||||
userService := users_services.GetUserService()
|
||||
@@ -200,7 +241,7 @@ func setUpRoutes(r *gin.Engine) {
|
||||
notifiers.GetNotifierController().RegisterRoutes(protected)
|
||||
storages.GetStorageController().RegisterRoutes(protected)
|
||||
databases.GetDatabaseController().RegisterRoutes(protected)
|
||||
backups.GetBackupController().RegisterRoutes(protected)
|
||||
backups_controllers.GetBackupController().RegisterRoutes(protected)
|
||||
restores.GetRestoreController().RegisterRoutes(protected)
|
||||
healthcheck_config.GetHealthcheckConfigController().RegisterRoutes(protected)
|
||||
healthcheck_attempt.GetHealthcheckAttemptController().RegisterRoutes(protected)
|
||||
@@ -208,52 +249,109 @@ func setUpRoutes(r *gin.Engine) {
|
||||
audit_logs.GetAuditLogController().RegisterRoutes(protected)
|
||||
users_controllers.GetManagementController().RegisterRoutes(protected)
|
||||
users_controllers.GetSettingsController().RegisterRoutes(protected)
|
||||
billing.GetBillingController().RegisterRoutes(protected)
|
||||
}
|
||||
|
||||
func setUpDependencies() {
|
||||
databases.SetupDependencies()
|
||||
backups.SetupDependencies()
|
||||
backups_services.SetupDependencies()
|
||||
restores.SetupDependencies()
|
||||
healthcheck_config.SetupDependencies()
|
||||
audit_logs.SetupDependencies()
|
||||
notifiers.SetupDependencies()
|
||||
storages.SetupDependencies()
|
||||
backups_config.SetupDependencies()
|
||||
task_cancellation.SetupDependencies()
|
||||
billing.SetupDependencies()
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
billing_paddle.SetupDependencies()
|
||||
}
|
||||
}
|
||||
|
||||
func runBackgroundTasks(log *slog.Logger) {
|
||||
log.Info("Preparing to run background tasks...")
|
||||
|
||||
// Create context that will be cancelled on shutdown
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Set up signal handling for graceful shutdown
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-quit
|
||||
log.Info("Shutdown signal received, cancelling all background tasks")
|
||||
cancel()
|
||||
}()
|
||||
|
||||
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
|
||||
if err != nil {
|
||||
log.Error("Failed to clean temp folder", "error", err)
|
||||
}
|
||||
|
||||
go runWithPanicLogging(log, "backup background service", func() {
|
||||
backups.GetBackupBackgroundService().Run()
|
||||
})
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
log.Info("Starting primary node background tasks...")
|
||||
|
||||
go runWithPanicLogging(log, "restore background service", func() {
|
||||
restores.GetRestoreBackgroundService().Run()
|
||||
})
|
||||
go runWithPanicLogging(log, "backup background service", func() {
|
||||
backuping.GetBackupsScheduler().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
|
||||
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run()
|
||||
})
|
||||
go runWithPanicLogging(log, "backup cleaner background service", func() {
|
||||
backuping.GetBackupCleaner().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "audit log cleanup background service", func() {
|
||||
audit_logs.GetAuditLogBackgroundService().Run()
|
||||
})
|
||||
go runWithPanicLogging(log, "restore background service", func() {
|
||||
restoring.GetRestoresScheduler().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "download token cleanup background service", func() {
|
||||
backups.GetDownloadTokenBackgroundService().Run()
|
||||
})
|
||||
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
|
||||
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "audit log cleanup background service", func() {
|
||||
audit_logs.GetAuditLogBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "download token cleanup background service", func() {
|
||||
backups_download.GetDownloadTokenBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "backup nodes registry background service", func() {
|
||||
backuping.GetBackupNodesRegistry().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "restore nodes registry background service", func() {
|
||||
restoring.GetRestoreNodesRegistry().Run(ctx)
|
||||
})
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
go runWithPanicLogging(log, "billing background service", func() {
|
||||
billing.GetBillingService().Run(ctx, *log)
|
||||
})
|
||||
}
|
||||
} else {
|
||||
log.Info("Skipping primary node tasks as not primary node")
|
||||
}
|
||||
|
||||
if config.GetEnv().IsProcessingNode {
|
||||
log.Info("Starting backup node background tasks...")
|
||||
|
||||
go runWithPanicLogging(log, "backup node", func() {
|
||||
backuping.GetBackuperNode().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "restore node", func() {
|
||||
restoring.GetRestorerNode().Run(ctx)
|
||||
})
|
||||
} else {
|
||||
log.Info("Skipping backup/restore node tasks as not backup node")
|
||||
}
|
||||
}
|
||||
|
||||
func runWithPanicLogging(log *slog.Logger, serviceName string, fn func()) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error("Panic in "+serviceName, "error", r)
|
||||
log.Error("Panic in "+serviceName, "error", r, "stacktrace", string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
fn()
|
||||
@@ -276,7 +374,9 @@ func generateSwaggerDocs(log *slog.Logger) {
|
||||
return
|
||||
}
|
||||
|
||||
cmd := exec.Command("swag", "init", "-d", currentDir, "-g", "cmd/main.go", "-o", "swagger")
|
||||
cmd := exec.CommandContext(
|
||||
context.Background(), "swag", "init", "-d", currentDir, "-g", "cmd/main.go", "-o", "swagger",
|
||||
)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
@@ -290,16 +390,13 @@ func generateSwaggerDocs(log *slog.Logger) {
|
||||
func runMigrations(log *slog.Logger) {
|
||||
log.Info("Running database migrations...")
|
||||
|
||||
cmd := exec.Command("goose", "up")
|
||||
cmd := exec.CommandContext(context.Background(), "goose", "-dir", "./migrations", "up")
|
||||
cmd.Env = append(
|
||||
os.Environ(),
|
||||
"GOOSE_DRIVER=postgres",
|
||||
"GOOSE_DBSTRING="+config.GetEnv().DatabaseDsn,
|
||||
)
|
||||
|
||||
// Set the working directory to where migrations are located
|
||||
cmd.Dir = "./migrations"
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
log.Error("Failed to run migrations", "error", err, "output", string(output))
|
||||
@@ -334,6 +431,25 @@ func enableCors(ginApp *gin.Engine) {
|
||||
}
|
||||
}
|
||||
|
||||
func ginRecoveryWithLogger(log *slog.Logger) gin.HandlerFunc {
|
||||
return func(ctx *gin.Context) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error("Panic recovered in HTTP handler",
|
||||
"error", r,
|
||||
"stacktrace", string(debug.Stack()),
|
||||
"method", ctx.Request.Method,
|
||||
"path", ctx.Request.URL.Path,
|
||||
)
|
||||
|
||||
ctx.AbortWithStatus(http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func mountFrontend(ginApp *gin.Engine) {
|
||||
staticDir := "./ui/build"
|
||||
ginApp.NoRoute(func(c *gin.Context) {
|
||||
|
||||
@@ -19,6 +19,35 @@ services:
|
||||
command: -p 5437
|
||||
shm_size: 10gb
|
||||
|
||||
# Valkey for caching
|
||||
dev-valkey:
|
||||
image: valkey/valkey:9.0.1-alpine
|
||||
ports:
|
||||
- "${VALKEY_PORT:-6379}:6379"
|
||||
volumes:
|
||||
- ./valkey-data:/data
|
||||
container_name: dev-valkey
|
||||
healthcheck:
|
||||
test: ["CMD", "valkey-cli", "ping"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 20s
|
||||
|
||||
# VictoriaLogs for external logging
|
||||
victoria-logs:
|
||||
image: victoriametrics/victoria-logs:latest
|
||||
container_name: victoria-logs
|
||||
ports:
|
||||
- "9428:9428"
|
||||
command:
|
||||
- -storageDataPath=/victoria-logs-data
|
||||
- -retentionPeriod=7d
|
||||
- -httpAuth.password=devpassword
|
||||
volumes:
|
||||
- ./victoria-logs-data:/victoria-logs-data
|
||||
restart: unless-stopped
|
||||
|
||||
# Test MinIO container
|
||||
test-minio:
|
||||
image: minio/minio:latest
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
module databasus-backend
|
||||
|
||||
go 1.24.4
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
|
||||
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3
|
||||
github.com/PaddleHQ/paddle-go-sdk v1.0.0
|
||||
github.com/gin-contrib/cors v1.7.5
|
||||
github.com/gin-contrib/gzip v1.2.3
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
@@ -25,9 +26,9 @@ require (
|
||||
github.com/swaggo/files v1.0.1
|
||||
github.com/swaggo/gin-swagger v1.6.0
|
||||
github.com/swaggo/swag v1.16.4
|
||||
github.com/valkey-io/valkey-go v1.0.70
|
||||
go.mongodb.org/mongo-driver v1.17.6
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/time v0.14.0
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
gorm.io/gorm v1.26.1
|
||||
)
|
||||
@@ -100,6 +101,8 @@ require (
|
||||
github.com/emersion/go-message v0.18.2 // indirect
|
||||
github.com/emersion/go-vcard v0.0.0-20241024213814-c9703dde27ff // indirect
|
||||
github.com/flynn/noise v1.1.0 // indirect
|
||||
github.com/ggicci/httpin v0.19.0 // indirect
|
||||
github.com/ggicci/owl v0.8.2 // indirect
|
||||
github.com/go-chi/chi/v5 v5.2.3 // indirect
|
||||
github.com/go-darwin/apfs v0.0.0-20211011131704-f84b94dbf348 // indirect
|
||||
github.com/go-git/go-billy/v5 v5.6.2 // indirect
|
||||
@@ -185,6 +188,7 @@ require (
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/term v0.38.0 // indirect
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
gopkg.in/validator.v2 v2.0.1 // indirect
|
||||
moul.io/http2curl/v2 v2.3.0 // indirect
|
||||
@@ -269,7 +273,7 @@ require (
|
||||
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.38.0 // indirect
|
||||
golang.org/x/arch v0.17.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/oauth2 v0.33.0
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
|
||||
@@ -77,6 +77,8 @@ github.com/Max-Sum/base32768 v0.0.0-20230304063302-18e6ce5945fd/go.mod h1:C8yoIf
|
||||
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/PaddleHQ/paddle-go-sdk v1.0.0 h1:+EXitsPFbRcc0CpQE/MIeudxiVOR8pFe/aOWTEUHDKU=
|
||||
github.com/PaddleHQ/paddle-go-sdk v1.0.0/go.mod h1:kbBBzf0BHEj38QvhtoELqlGip3alKgA/I+vl7RQzB58=
|
||||
github.com/ProtonMail/bcrypt v0.0.0-20210511135022-227b4adcab57/go.mod h1:HecWFHognK8GfRDGnFQbW/LiV7A3MX3gZVs45vk5h8I=
|
||||
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs69zUkSzubzjBbL+cmOXgnmt9Fyd9ug=
|
||||
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo=
|
||||
@@ -248,6 +250,10 @@ github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9t
|
||||
github.com/geoffgarside/ber v1.1.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc=
|
||||
github.com/geoffgarside/ber v1.2.0 h1:/loowoRcs/MWLYmGX9QtIAbA+V/FrnVLsMMPhwiRm64=
|
||||
github.com/geoffgarside/ber v1.2.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc=
|
||||
github.com/ggicci/httpin v0.19.0 h1:p0B3SWLVgg770VirYiHB14M5wdRx3zR8mCTzM/TkTQ8=
|
||||
github.com/ggicci/httpin v0.19.0/go.mod h1:hzsQHcbqLabmGOycf7WNw6AAzcVbsMeoOp46bWAbIWc=
|
||||
github.com/ggicci/owl v0.8.2 h1:og+lhqpzSMPDdEB+NJfzoAJARP7qCG3f8uUC3xvGukA=
|
||||
github.com/ggicci/owl v0.8.2/go.mod h1:PHRD57u41vFN5UtFz2SF79yTVoM3HlWpjMiE+ZU2dj4=
|
||||
github.com/gin-contrib/cors v1.7.5 h1:cXC9SmofOrRg0w9PigwGlHG3ztswH6bqq4vJVXnvYMk=
|
||||
github.com/gin-contrib/cors v1.7.5/go.mod h1:4q3yi7xBEDDWKapjT2o1V7mScKDDr8k+jZ0fSquGoy0=
|
||||
github.com/gin-contrib/gzip v1.2.3 h1:dAhT722RuEG330ce2agAs75z7yB+NKvX/ZM1r8w0u2U=
|
||||
@@ -454,6 +460,8 @@ github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1
|
||||
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
|
||||
github.com/jtolio/noiseconn v0.0.0-20231127013910-f6d9ecbf1de7 h1:JcltaO1HXM5S2KYOYcKgAV7slU0xPy1OcvrVgn98sRQ=
|
||||
github.com/jtolio/noiseconn v0.0.0-20231127013910-f6d9ecbf1de7/go.mod h1:MEkhEPFwP3yudWO0lj6vfYpLIB+3eIcuIW+e0AZzUQk=
|
||||
github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo=
|
||||
github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA=
|
||||
github.com/jzelinskie/whirlpool v0.0.0-20201016144138-0675e54bb004 h1:G+9t9cEtnC9jFiTxyptEKuNIAbiN5ZCQzX2a74lj3xg=
|
||||
github.com/jzelinskie/whirlpool v0.0.0-20201016144138-0675e54bb004/go.mod h1:KmHnJWQrgEvbuy0vcvj00gtMqbvNn1L+3YUZLK/B92c=
|
||||
github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU=
|
||||
@@ -539,8 +547,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||
github.com/onsi/ginkgo/v2 v2.17.3 h1:oJcvKpIb7/8uLpDDtnQuf18xVnwKp8DTD7DQ6gTd/MU=
|
||||
github.com/onsi/ginkgo/v2 v2.17.3/go.mod h1:nP2DPOQoNsQmsVyv5rDA8JkXQoCs6goXIvr/PRJ1eCc=
|
||||
github.com/onsi/gomega v1.37.0 h1:CdEG8g0S133B4OswTDC/5XPSzE1OeP29QOioj2PID2Y=
|
||||
github.com/onsi/gomega v1.37.0/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0=
|
||||
github.com/onsi/gomega v1.38.3 h1:eTX+W6dobAYfFeGC2PV6RwXRu/MyT+cQguijutvkpSM=
|
||||
github.com/onsi/gomega v1.38.3/go.mod h1:ZCU1pkQcXDO5Sl9/VVEGlDyp+zm0m1cmeG5TOzLgdh4=
|
||||
github.com/oracle/oci-go-sdk/v65 v65.104.0 h1:l9awEvzWvxmYhy/97A0hZ87pa7BncYXmcO/S8+rvgK0=
|
||||
github.com/oracle/oci-go-sdk/v65 v65.104.0/go.mod h1:oB8jFGVc/7/zJ+DbleE8MzGHjhs2ioCz5stRTdZdIcY=
|
||||
github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg=
|
||||
@@ -660,6 +668,8 @@ github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
|
||||
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
|
||||
github.com/unknwon/goconfig v1.0.0 h1:rS7O+CmUdli1T+oDm7fYj1MwqNWtEJfNj+FqcUHML8U=
|
||||
github.com/unknwon/goconfig v1.0.0/go.mod h1:qu2ZQ/wcC/if2u32263HTVC39PeOQRSmidQk3DuDFQ8=
|
||||
github.com/valkey-io/valkey-go v1.0.70 h1:mjYNT8qiazxDAJ0QNQ8twWT/YFOkOoRd40ERV2mB49Y=
|
||||
github.com/valkey-io/valkey-go v1.0.70/go.mod h1:VGhZ6fs68Qrn2+OhH+6waZH27bjpgQOiLyUQyXuYK5k=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM=
|
||||
@@ -720,6 +730,8 @@ go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
|
||||
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
|
||||
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
||||
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/arch v0.17.0 h1:4O3dfLzd+lQewptAHqjewQZQDyEdejz3VwgeYwkZneU=
|
||||
golang.org/x/arch v0.17.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
@@ -818,8 +830,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
env_utils "databasus-backend/internal/util/env"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"databasus-backend/internal/util/tools"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ilyakaznacheev/cleanenv"
|
||||
"github.com/joho/godotenv"
|
||||
|
||||
env_utils "databasus-backend/internal/util/env"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
var log = logger.GetLogger()
|
||||
@@ -22,17 +24,50 @@ const (
|
||||
|
||||
type EnvVariables struct {
|
||||
IsTesting bool
|
||||
DatabaseDsn string `env:"DATABASE_DSN" required:"true"`
|
||||
EnvMode env_utils.EnvMode `env:"ENV_MODE" required:"true"`
|
||||
PostgresesInstallDir string `env:"POSTGRES_INSTALL_DIR"`
|
||||
MysqlInstallDir string `env:"MYSQL_INSTALL_DIR"`
|
||||
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
|
||||
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`
|
||||
|
||||
// Internal database
|
||||
DatabaseDsn string `env:"DATABASE_DSN" required:"true"`
|
||||
// Internal Valkey
|
||||
ValkeyHost string `env:"VALKEY_HOST" required:"true"`
|
||||
ValkeyPort string `env:"VALKEY_PORT" required:"true"`
|
||||
ValkeyUsername string `env:"VALKEY_USERNAME" required:"true"`
|
||||
ValkeyPassword string `env:"VALKEY_PASSWORD" required:"true"`
|
||||
ValkeyIsSsl bool `env:"VALKEY_IS_SSL" required:"true"`
|
||||
|
||||
IsCloud bool `env:"IS_CLOUD"`
|
||||
TestLocalhost string `env:"TEST_LOCALHOST"`
|
||||
|
||||
ShowDbInstallationVerificationLogs bool `env:"SHOW_DB_INSTALLATION_VERIFICATION_LOGS"`
|
||||
IsSkipExternalResourcesTests bool `env:"IS_SKIP_EXTERNAL_RESOURCES_TESTS"`
|
||||
|
||||
IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
|
||||
IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
|
||||
IsProcessingNode bool `env:"IS_PROCESSING_NODE"`
|
||||
NodeNetworkThroughputMBs int `env:"NODE_NETWORK_THROUGHPUT_MBPS"`
|
||||
|
||||
DataFolder string
|
||||
TempFolder string
|
||||
SecretKeyPath string
|
||||
|
||||
// Billing (always tax-exclusive)
|
||||
PricePerGBCents int64 `env:"PRICE_PER_GB_CENTS"`
|
||||
MinStorageGB int
|
||||
MaxStorageGB int
|
||||
TrialDuration time.Duration
|
||||
TrialStorageGB int
|
||||
GracePeriod time.Duration
|
||||
// Paddle billing
|
||||
IsPaddleSandbox bool `env:"IS_PADDLE_SANDBOX"`
|
||||
PaddleApiKey string `env:"PADDLE_API_KEY"`
|
||||
PaddleWebhookSecret string `env:"PADDLE_WEBHOOK_SECRET"`
|
||||
PaddlePriceID string `env:"PADDLE_PRICE_ID"`
|
||||
PaddleClientToken string `env:"PADDLE_CLIENT_TOKEN"`
|
||||
|
||||
TestGoogleDriveClientID string `env:"TEST_GOOGLE_DRIVE_CLIENT_ID"`
|
||||
TestGoogleDriveClientSecret string `env:"TEST_GOOGLE_DRIVE_CLIENT_SECRET"`
|
||||
TestGoogleDriveTokenJSON string `env:"TEST_GOOGLE_DRIVE_TOKEN_JSON"`
|
||||
@@ -85,9 +120,9 @@ type EnvVariables struct {
|
||||
GoogleClientID string `env:"GOOGLE_CLIENT_ID"`
|
||||
GoogleClientSecret string `env:"GOOGLE_CLIENT_SECRET"`
|
||||
|
||||
// testing Telegram
|
||||
TestTelegramBotToken string `env:"TEST_TELEGRAM_BOT_TOKEN"`
|
||||
TestTelegramChatID string `env:"TEST_TELEGRAM_CHAT_ID"`
|
||||
// Cloudflare Turnstile
|
||||
CloudflareTurnstileSecretKey string `env:"CLOUDFLARE_TURNSTILE_SECRET_KEY"`
|
||||
CloudflareTurnstileSiteKey string `env:"CLOUDFLARE_TURNSTILE_SITE_KEY"`
|
||||
|
||||
// testing Supabase
|
||||
TestSupabaseHost string `env:"TEST_SUPABASE_HOST"`
|
||||
@@ -95,16 +130,25 @@ type EnvVariables struct {
|
||||
TestSupabaseUsername string `env:"TEST_SUPABASE_USERNAME"`
|
||||
TestSupabasePassword string `env:"TEST_SUPABASE_PASSWORD"`
|
||||
TestSupabaseDatabase string `env:"TEST_SUPABASE_DATABASE"`
|
||||
|
||||
// SMTP configuration (optional)
|
||||
SMTPHost string `env:"SMTP_HOST"`
|
||||
SMTPPort int `env:"SMTP_PORT"`
|
||||
SMTPUser string `env:"SMTP_USER"`
|
||||
SMTPPassword string `env:"SMTP_PASSWORD"`
|
||||
SMTPFrom string `env:"SMTP_FROM"`
|
||||
|
||||
// Application URL (optional) - used for email links
|
||||
DatabasusURL string `env:"DATABASUS_URL"`
|
||||
}
|
||||
|
||||
var (
|
||||
env EnvVariables
|
||||
once sync.Once
|
||||
)
|
||||
var env EnvVariables
|
||||
|
||||
func GetEnv() EnvVariables {
|
||||
once.Do(loadEnvVariables)
|
||||
return env
|
||||
var initEnv = sync.OnceFunc(loadEnvVariables)
|
||||
|
||||
func GetEnv() *EnvVariables {
|
||||
initEnv()
|
||||
return &env
|
||||
}
|
||||
|
||||
func loadEnvVariables() {
|
||||
@@ -155,6 +199,21 @@ func loadEnvVariables() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Set default value for ShowDbInstallationVerificationLogs if not defined
|
||||
if os.Getenv("SHOW_DB_INSTALLATION_VERIFICATION_LOGS") == "" {
|
||||
env.ShowDbInstallationVerificationLogs = true
|
||||
}
|
||||
|
||||
// Set default value for IsSkipExternalTests if not defined
|
||||
if os.Getenv("IS_SKIP_EXTERNAL_RESOURCES_TESTS") == "" {
|
||||
env.IsSkipExternalResourcesTests = false
|
||||
}
|
||||
|
||||
// Set default value for IsCloud if not defined
|
||||
if os.Getenv("IS_CLOUD") == "" {
|
||||
env.IsCloud = false
|
||||
}
|
||||
|
||||
for _, arg := range os.Args {
|
||||
if strings.Contains(arg, "test") {
|
||||
env.IsTesting = true
|
||||
@@ -162,6 +221,14 @@ func loadEnvVariables() {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for external database override
|
||||
if externalDsn := os.Getenv("DANGEROUS_EXTERNAL_DATABASE_DSN"); externalDsn != "" {
|
||||
log.Warn(
|
||||
"Using DANGEROUS_EXTERNAL_DATABASE_DSN - connecting to external database instead of internal PostgreSQL",
|
||||
)
|
||||
env.DatabaseDsn = externalDsn
|
||||
}
|
||||
|
||||
if env.DatabaseDsn == "" {
|
||||
log.Error("DATABASE_DSN is empty")
|
||||
os.Exit(1)
|
||||
@@ -178,16 +245,80 @@ func loadEnvVariables() {
|
||||
log.Info("ENV_MODE loaded", "mode", env.EnvMode)
|
||||
|
||||
env.PostgresesInstallDir = filepath.Join(backendRoot, "tools", "postgresql")
|
||||
tools.VerifyPostgresesInstallation(log, env.EnvMode, env.PostgresesInstallDir)
|
||||
tools.VerifyPostgresesInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.PostgresesInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
env.MysqlInstallDir = filepath.Join(backendRoot, "tools", "mysql")
|
||||
tools.VerifyMysqlInstallation(log, env.EnvMode, env.MysqlInstallDir)
|
||||
tools.VerifyMysqlInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.MysqlInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
env.MariadbInstallDir = filepath.Join(backendRoot, "tools", "mariadb")
|
||||
tools.VerifyMariadbInstallation(log, env.EnvMode, env.MariadbInstallDir)
|
||||
tools.VerifyMariadbInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.MariadbInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
env.MongodbInstallDir = filepath.Join(backendRoot, "tools", "mongodb")
|
||||
tools.VerifyMongodbInstallation(log, env.EnvMode, env.MongodbInstallDir)
|
||||
tools.VerifyMongodbInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.MongodbInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
if env.NodeNetworkThroughputMBs == 0 {
|
||||
env.NodeNetworkThroughputMBs = 125 // 1 Gbit/s
|
||||
}
|
||||
|
||||
if !env.IsManyNodesMode {
|
||||
env.IsPrimaryNode = true
|
||||
env.IsProcessingNode = true
|
||||
}
|
||||
|
||||
if env.TestLocalhost == "" {
|
||||
env.TestLocalhost = "localhost"
|
||||
}
|
||||
|
||||
// Valkey
|
||||
if env.ValkeyHost == "" {
|
||||
log.Error("VALKEY_HOST is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
if env.ValkeyPort == "" {
|
||||
log.Error("VALKEY_PORT is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Check for external Valkey override
|
||||
if externalValkeyHost := os.Getenv("DANGEROUS_VALKEY_HOST"); externalValkeyHost != "" {
|
||||
log.Warn(
|
||||
"Using DANGEROUS_VALKEY_* variables - connecting to external Valkey instead of internal instance",
|
||||
)
|
||||
env.ValkeyHost = externalValkeyHost
|
||||
|
||||
if externalValkeyPort := os.Getenv("DANGEROUS_VALKEY_PORT"); externalValkeyPort != "" {
|
||||
env.ValkeyPort = externalValkeyPort
|
||||
}
|
||||
if externalValkeyUsername := os.Getenv("DANGEROUS_VALKEY_USERNAME"); externalValkeyUsername != "" {
|
||||
env.ValkeyUsername = externalValkeyUsername
|
||||
}
|
||||
if externalValkeyPassword := os.Getenv("DANGEROUS_VALKEY_PASSWORD"); externalValkeyPassword != "" {
|
||||
env.ValkeyPassword = externalValkeyPassword
|
||||
}
|
||||
if externalValkeyIsSsl := os.Getenv("DANGEROUS_VALKEY_IS_SSL"); externalValkeyIsSsl != "" {
|
||||
env.ValkeyIsSsl = externalValkeyIsSsl == "true"
|
||||
}
|
||||
}
|
||||
|
||||
// Store the data and temp folders one level below the root
|
||||
// (projectRoot/databasus-data -> /databasus-data)
|
||||
@@ -244,16 +375,41 @@ func loadEnvVariables() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.TestTelegramBotToken == "" {
|
||||
log.Error("TEST_TELEGRAM_BOT_TOKEN is empty")
|
||||
}
|
||||
|
||||
// Billing
|
||||
if env.IsCloud {
|
||||
if env.PricePerGBCents == 0 {
|
||||
log.Error("PRICE_PER_GB_CENTS is empty or zero")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.TestTelegramChatID == "" {
|
||||
log.Error("TEST_TELEGRAM_CHAT_ID is empty")
|
||||
if env.PaddleApiKey == "" {
|
||||
log.Error("PADDLE_API_KEY is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.PaddleWebhookSecret == "" {
|
||||
log.Error("PADDLE_WEBHOOK_SECRET is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.PaddlePriceID == "" {
|
||||
log.Error("PADDLE_PRICE_ID is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.PaddleClientToken == "" {
|
||||
log.Error("PADDLE_CLIENT_TOKEN is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
env.MinStorageGB = 20
|
||||
env.MaxStorageGB = 10_000
|
||||
env.TrialDuration = 24 * time.Hour
|
||||
env.TrialStorageGB = 20
|
||||
env.GracePeriod = 30 * 24 * time.Hour
|
||||
|
||||
log.Info("Environment variables loaded successfully!")
|
||||
}
|
||||
|
||||
@@ -1,33 +1,43 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AuditLogBackgroundService struct {
|
||||
auditLogService *AuditLogService
|
||||
logger *slog.Logger
|
||||
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *AuditLogBackgroundService) Run() {
|
||||
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
|
||||
if s.hasRun.Swap(true) {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
|
||||
s.logger.Info("Starting audit log cleanup background service")
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
if config.IsShouldShutdown() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/gorm"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
func Test_CleanOldAuditLogs_DeletesLogsOlderThanOneYear(t *testing.T) {
|
||||
@@ -102,7 +102,7 @@ func Test_CleanOldAuditLogs_DeletesMultipleOldLogs(t *testing.T) {
|
||||
|
||||
// Create many old logs with specific UUIDs to track them
|
||||
testLogIDs := make([]uuid.UUID, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
for i := range 5 {
|
||||
testLogIDs[i] = uuid.New()
|
||||
daysAgo := 400 + (i * 10)
|
||||
log := &AuditLog{
|
||||
|
||||
@@ -4,10 +4,10 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
|
||||
@@ -6,15 +6,15 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_GetGlobalAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly(t *testing.T) {
|
||||
|
||||
@@ -1,21 +1,27 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var auditLogRepository = &AuditLogRepository{}
|
||||
var auditLogService = &AuditLogService{
|
||||
auditLogRepository,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
var (
|
||||
auditLogRepository = &AuditLogRepository{}
|
||||
auditLogService = &AuditLogService{
|
||||
auditLogRepository,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
)
|
||||
|
||||
var auditLogController = &AuditLogController{
|
||||
auditLogService,
|
||||
}
|
||||
|
||||
var auditLogBackgroundService = &AuditLogBackgroundService{
|
||||
auditLogService,
|
||||
logger.GetLogger(),
|
||||
auditLogService: auditLogService,
|
||||
logger: logger.GetLogger(),
|
||||
}
|
||||
|
||||
func GetAuditLogService() *AuditLogService {
|
||||
@@ -30,8 +36,8 @@ func GetAuditLogBackgroundService() *AuditLogBackgroundService {
|
||||
return auditLogBackgroundService
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
var SetupDependencies = sync.OnceFunc(func() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
type AuditLogRepository struct{}
|
||||
@@ -21,7 +22,7 @@ func (r *AuditLogRepository) GetGlobal(
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLogDTO, error) {
|
||||
var auditLogs = make([]*AuditLogDTO, 0)
|
||||
auditLogs := make([]*AuditLogDTO, 0)
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
@@ -37,7 +38,7 @@ func (r *AuditLogRepository) GetGlobal(
|
||||
LEFT JOIN users u ON al.user_id = u.id
|
||||
LEFT JOIN workspaces w ON al.workspace_id = w.id`
|
||||
|
||||
args := []interface{}{}
|
||||
args := []any{}
|
||||
|
||||
if beforeDate != nil {
|
||||
sql += " WHERE al.created_at < ?"
|
||||
@@ -57,7 +58,7 @@ func (r *AuditLogRepository) GetByUser(
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLogDTO, error) {
|
||||
var auditLogs = make([]*AuditLogDTO, 0)
|
||||
auditLogs := make([]*AuditLogDTO, 0)
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
@@ -74,7 +75,7 @@ func (r *AuditLogRepository) GetByUser(
|
||||
LEFT JOIN workspaces w ON al.workspace_id = w.id
|
||||
WHERE al.user_id = ?`
|
||||
|
||||
args := []interface{}{userID}
|
||||
args := []any{userID}
|
||||
|
||||
if beforeDate != nil {
|
||||
sql += " AND al.created_at < ?"
|
||||
@@ -94,7 +95,7 @@ func (r *AuditLogRepository) GetByWorkspace(
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLogDTO, error) {
|
||||
var auditLogs = make([]*AuditLogDTO, 0)
|
||||
auditLogs := make([]*AuditLogDTO, 0)
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
@@ -111,7 +112,7 @@ func (r *AuditLogRepository) GetByWorkspace(
|
||||
LEFT JOIN workspaces w ON al.workspace_id = w.id
|
||||
WHERE al.workspace_id = ?`
|
||||
|
||||
args := []interface{}{workspaceID}
|
||||
args := []any{workspaceID}
|
||||
|
||||
if beforeDate != nil {
|
||||
sql += " AND al.created_at < ?"
|
||||
|
||||
@@ -4,10 +4,10 @@ import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogService struct {
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
)
|
||||
|
||||
func Test_AuditLogs_WorkspaceSpecificLogs(t *testing.T) {
|
||||
|
||||
@@ -1,254 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/storages"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/period"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
type BackupBackgroundService struct {
|
||||
backupService *BackupService
|
||||
backupRepository *BackupRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
|
||||
lastBackupTime time.Time
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) Run() {
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.cleanOldBackups(); err != nil {
|
||||
s.logger.Error("Failed to clean old backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
time.Sleep(1 * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) IsBackupsWorkerRunning() bool {
|
||||
// if last backup time is more than 5 minutes ago, return false
|
||||
return s.lastBackupTime.After(time.Now().UTC().Add(-5 * time.Minute))
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) failBackupsInProgress() error {
|
||||
backupsInProgress, err := s.backupRepository.FindByStatus(BackupStatusInProgress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backup := range backupsInProgress {
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
failMessage := "Backup failed due to application restart"
|
||||
backup.FailMessage = &failMessage
|
||||
backup.Status = BackupStatusFailed
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
s.backupService.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&failMessage,
|
||||
)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) cleanOldBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
backupStorePeriod := backupConfig.StorePeriod
|
||||
|
||||
if backupStorePeriod == period.PeriodForever {
|
||||
continue
|
||||
}
|
||||
|
||||
storeDuration := backupStorePeriod.ToDuration()
|
||||
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
|
||||
|
||||
oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
|
||||
backupConfig.DatabaseID,
|
||||
dateBeforeBackupsShouldBeDeleted,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find old backups for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, backup := range oldBackups {
|
||||
storage, err := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get storage by ID",
|
||||
"storageId",
|
||||
backup.StorageID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
err = storage.DeleteFile(encryptor, backup.ID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
|
||||
}
|
||||
|
||||
if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
|
||||
s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Deleted old backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) runPendingBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
if backupConfig.BackupInterval == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get last backup for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
var lastBackupTime *time.Time
|
||||
if lastBackup != nil {
|
||||
lastBackupTime = &lastBackup.CreatedAt
|
||||
}
|
||||
|
||||
remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
|
||||
|
||||
if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
|
||||
remainedBackupTryCount > 0 {
|
||||
s.logger.Info(
|
||||
"Triggering scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"intervalType",
|
||||
backupConfig.BackupInterval.Interval,
|
||||
)
|
||||
|
||||
go s.backupService.MakeBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
|
||||
s.logger.Info(
|
||||
"Successfully triggered scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRemainedBackupTryCount returns the number of remaining backup tries for a given backup.
|
||||
// If the backup is not failed or the backup config does not allow retries, it returns 0.
|
||||
// If the backup is failed and the backup config allows retries, it returns the number of remaining tries.
|
||||
// If the backup is failed and the backup config does not allow retries, it returns 0.
|
||||
func (s *BackupBackgroundService) GetRemainedBackupTryCount(lastBackup *Backup) int {
|
||||
if lastBackup == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if lastBackup.Status != BackupStatusFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
if !backupConfig.IsRetryIfFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
maxFailedTriesCount := backupConfig.MaxFailedTriesCount
|
||||
|
||||
lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
|
||||
lastBackup.DatabaseID,
|
||||
maxFailedTriesCount,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to find last backups by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
lastFailedBackups := make([]*Backup, 0)
|
||||
|
||||
for _, backup := range lastBackups {
|
||||
if backup.Status == BackupStatusFailed {
|
||||
lastFailedBackups = append(lastFailedBackups, backup)
|
||||
}
|
||||
}
|
||||
|
||||
return maxFailedTriesCount - len(lastFailedBackups)
|
||||
}
|
||||
@@ -1,389 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/period"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add old backup
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
// Wait for backup to complete (runs in goroutine)
|
||||
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 2)
|
||||
}
|
||||
|
||||
func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add recent backup (1 hour ago)
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1) // Should still be 1 backup, no new backup created
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database with retries disabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = false
|
||||
backupConfig.MaxFailedTriesCount = 0
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add failed backup
|
||||
failMessage := "backup failed"
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1) // Should still be 1 backup, no retry attempted
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database with retries enabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = true
|
||||
backupConfig.MaxFailedTriesCount = 3
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add failed backup
|
||||
failMessage := "backup failed"
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
// Wait for backup to complete (runs in goroutine)
|
||||
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 2) // Should have 2 backups, retry was attempted
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database with retries enabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = true
|
||||
backupConfig.MaxFailedTriesCount = 3
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
failMessage := "backup failed"
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
}
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 3) // Should have 3 backups, not more than max
|
||||
}
|
||||
|
||||
func Test_MakeBackgroundBackupWhenBakupsDisabled_BackupSkipped(t *testing.T) {
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = false
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add old backup that would trigger new backup if enabled
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackupContextManager struct {
|
||||
mu sync.RWMutex
|
||||
cancelFuncs map[uuid.UUID]context.CancelFunc
|
||||
cancelledBackups map[uuid.UUID]bool
|
||||
}
|
||||
|
||||
func NewBackupContextManager() *BackupContextManager {
|
||||
return &BackupContextManager{
|
||||
cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
|
||||
cancelledBackups: make(map[uuid.UUID]bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.cancelFuncs[backupID] = cancelFunc
|
||||
delete(m.cancelledBackups, backupID)
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.cancelledBackups[backupID] {
|
||||
return nil
|
||||
}
|
||||
|
||||
cancelFunc, exists := m.cancelFuncs[backupID]
|
||||
if exists {
|
||||
cancelFunc()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
}
|
||||
|
||||
m.cancelledBackups[backupID] = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) IsCancelled(backupID uuid.UUID) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.cancelledBackups[backupID]
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) UnregisterBackup(backupID uuid.UUID) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
delete(m.cancelledBackups, backupID)
|
||||
}
|
||||
441
backend/internal/features/backups/backups/backuping/backuper.go
Normal file
441
backend/internal/features/backups/backups/backuping/backuper.go
Normal file
@@ -0,0 +1,441 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
)
|
||||
|
||||
const (
|
||||
heartbeatTickerInterval = 15 * time.Second
|
||||
backuperHeathcheckThreshold = 5 * time.Minute
|
||||
)
|
||||
|
||||
type BackuperNode struct {
|
||||
databaseService *databases.DatabaseService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
backupRepository *backups_core.BackupRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
notificationSender backups_core.NotificationSender
|
||||
backupCancelManager *tasks_cancellation.TaskCancelManager
|
||||
backupNodesRegistry *BackupNodesRegistry
|
||||
logger *slog.Logger
|
||||
createBackupUseCase backups_core.CreateBackupUsecase
|
||||
nodeID uuid.UUID
|
||||
|
||||
lastHeartbeat time.Time
|
||||
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (n *BackuperNode) Run(ctx context.Context) {
|
||||
if n.hasRun.Swap(true) {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", n))
|
||||
}
|
||||
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
|
||||
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
|
||||
|
||||
backupNode := BackupNode{
|
||||
ID: n.nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
|
||||
n.logger.Error("Failed to register node in registry", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
|
||||
go func() {
|
||||
n.MakeBackup(backupID, isCallNotifier)
|
||||
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
|
||||
n.logger.Error(
|
||||
"Failed to publish backup completion",
|
||||
"error",
|
||||
err,
|
||||
"backupID",
|
||||
backupID,
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
|
||||
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(heartbeatTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
|
||||
|
||||
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
|
||||
n.logger.Error("Failed to unregister node from registry", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
n.sendHeartbeat(&backupNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *BackuperNode) IsBackuperRunning() bool {
|
||||
return n.lastHeartbeat.After(time.Now().UTC().Add(-backuperHeathcheckThreshold))
|
||||
}
|
||||
|
||||
func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
backup, err := n.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get backup by ID", "backupId", backupID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
databaseID := backup.DatabaseID
|
||||
|
||||
database, err := n.databaseService.GetDatabaseByID(databaseID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get database by ID", "databaseId", databaseID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backupConfig, err := n.backupConfigService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if backupConfig.StorageID == nil {
|
||||
n.logger.Error("Backup config storage ID is not defined")
|
||||
return
|
||||
}
|
||||
|
||||
storage, err := n.storageService.GetStorageByID(*backupConfig.StorageID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get storage by ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now().UTC()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
n.backupCancelManager.RegisterTask(backup.ID, cancel)
|
||||
defer n.backupCancelManager.UnregisterTask(backup.ID)
|
||||
|
||||
backupProgressListener := func(
|
||||
completedMBs float64,
|
||||
) {
|
||||
backup.BackupSizeMb = completedMBs
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to update backup progress", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
backupMetadata, err := n.createBackupUseCase.Execute(
|
||||
ctx,
|
||||
backup,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
backupProgressListener,
|
||||
)
|
||||
if err != nil {
|
||||
// Check if backup was already marked as failed by progress listener (e.g., size limit exceeded)
|
||||
// If so, skip error handling to avoid overwriting the status
|
||||
currentBackup, fetchErr := n.backupRepository.FindByID(backup.ID)
|
||||
if fetchErr == nil && currentBackup.Status == backups_core.BackupStatusFailed {
|
||||
n.logger.Warn(
|
||||
"Backup already marked as failed by progress listener, skipping error handling",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"failMessage",
|
||||
*currentBackup.FailMessage,
|
||||
)
|
||||
|
||||
// Still call notification for size limit failures
|
||||
n.SendBackupNotification(
|
||||
backupConfig,
|
||||
currentBackup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
currentBackup.FailMessage,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
errMsg := err.Error()
|
||||
|
||||
// Log detailed error information for debugging
|
||||
n.logger.Error("Backup execution failed",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", databaseID,
|
||||
"databaseType", database.Type,
|
||||
"storageId", storage.ID,
|
||||
"storageType", storage.Type,
|
||||
"error", err,
|
||||
"errorMessage", errMsg,
|
||||
)
|
||||
|
||||
// Check if backup was cancelled (not due to shutdown)
|
||||
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
|
||||
strings.Contains(errMsg, "context canceled") ||
|
||||
errors.Is(err, context.Canceled)
|
||||
isShutdown := strings.Contains(errMsg, "shutdown")
|
||||
|
||||
if isCancelled && !isShutdown {
|
||||
n.logger.Warn("Backup was cancelled by user or system",
|
||||
"backupId", backup.ID,
|
||||
"isCancelled", isCancelled,
|
||||
"isShutdown", isShutdown,
|
||||
)
|
||||
|
||||
backup.Status = backups_core.BackupStatusCanceled
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save cancelled backup", "error", err)
|
||||
}
|
||||
|
||||
// Delete partial backup from storage
|
||||
storage, storageErr := n.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr == nil {
|
||||
if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.FileName); deleteErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to delete partial backup file",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
deleteErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.FailMessage = &errMsg
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if updateErr := n.databaseService.SetBackupError(databaseID, errMsg); updateErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup", "error", err)
|
||||
}
|
||||
|
||||
n.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&errMsg,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.Status = backups_core.BackupStatusCompleted
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Update backup with encryption metadata if provided
|
||||
if backupMetadata != nil {
|
||||
backupMetadata.BackupID = backup.ID
|
||||
|
||||
if err := backupMetadata.Validate(); err != nil {
|
||||
n.logger.Error("Failed to validate backup metadata", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backup.EncryptionSalt = backupMetadata.EncryptionSalt
|
||||
backup.EncryptionIV = backupMetadata.EncryptionIV
|
||||
backup.Encryption = backupMetadata.Encryption
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Save metadata file to storage
|
||||
if backupMetadata != nil {
|
||||
metadataJSON, err := json.Marshal(backupMetadata)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to marshal backup metadata to JSON",
|
||||
"backupId", backup.ID,
|
||||
"error", err,
|
||||
)
|
||||
} else {
|
||||
metadataReader := bytes.NewReader(metadataJSON)
|
||||
metadataFileName := backup.FileName + ".metadata"
|
||||
|
||||
if err := storage.SaveFile(
|
||||
context.Background(),
|
||||
n.fieldEncryptor,
|
||||
n.logger,
|
||||
metadataFileName,
|
||||
metadataReader,
|
||||
); err != nil {
|
||||
n.logger.Error("Failed to save backup metadata file to storage",
|
||||
"backupId", backup.ID,
|
||||
"fileName", metadataFileName,
|
||||
"error", err,
|
||||
)
|
||||
} else {
|
||||
n.logger.Info("Backup metadata file saved successfully",
|
||||
"backupId", backup.ID,
|
||||
"fileName", metadataFileName,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update database last backup time
|
||||
now := time.Now().UTC()
|
||||
if updateErr := n.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if backup.Status != backups_core.BackupStatusCompleted && !isCallNotifier {
|
||||
return
|
||||
}
|
||||
|
||||
n.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupSuccess,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func (n *BackuperNode) SendBackupNotification(
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
backup *backups_core.Backup,
|
||||
notificationType backups_config.BackupNotificationType,
|
||||
errorMessage *string,
|
||||
) {
|
||||
database, err := n.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
workspace, err := n.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, notifier := range database.Notifiers {
|
||||
if !slices.Contains(
|
||||
backupConfig.SendNotificationsOn,
|
||||
notificationType,
|
||||
) {
|
||||
continue
|
||||
}
|
||||
|
||||
title := ""
|
||||
switch notificationType {
|
||||
case backups_config.NotificationBackupFailed:
|
||||
title = fmt.Sprintf(
|
||||
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
case backups_config.NotificationBackupSuccess:
|
||||
title = fmt.Sprintf(
|
||||
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
}
|
||||
|
||||
message := ""
|
||||
if errorMessage != nil {
|
||||
message = *errorMessage
|
||||
} else {
|
||||
// Format size conditionally
|
||||
var sizeStr string
|
||||
if backup.BackupSizeMb < 1024 {
|
||||
sizeStr = fmt.Sprintf("%.2f MB", backup.BackupSizeMb)
|
||||
} else {
|
||||
sizeGB := backup.BackupSizeMb / 1024
|
||||
sizeStr = fmt.Sprintf("%.2f GB", sizeGB)
|
||||
}
|
||||
|
||||
// Format duration as "0m 0s 0ms"
|
||||
totalMs := backup.BackupDurationMs
|
||||
minutes := totalMs / (1000 * 60)
|
||||
seconds := (totalMs % (1000 * 60)) / 1000
|
||||
durationStr := fmt.Sprintf("%dm %ds", minutes, seconds)
|
||||
|
||||
message = fmt.Sprintf(
|
||||
"Backup completed successfully in %s.\nCompressed backup size: %s",
|
||||
durationStr,
|
||||
sizeStr,
|
||||
)
|
||||
}
|
||||
|
||||
n.notificationSender.SendNotification(
|
||||
¬ifier,
|
||||
title,
|
||||
message,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *BackuperNode) sendHeartbeat(backupNode *BackupNode) {
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
|
||||
n.logger.Error("Failed to send heartbeat", "error", err)
|
||||
}
|
||||
}
|
||||
@@ -1,31 +1,26 @@
|
||||
package backups
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
)
|
||||
|
||||
func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
@@ -50,23 +45,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
|
||||
t.Run("BackupFailed_FailNotificationSent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backupService := &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateFailedBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
nil,
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.notificationSender = mockNotificationSender
|
||||
backuperNode.createBackupUseCase = &CreateFailedBackupUsecase{}
|
||||
|
||||
// Create a backup record directly that will be looked up by MakeBackup
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Set up expectations
|
||||
mockNotificationSender.On("SendNotification",
|
||||
@@ -79,7 +70,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
}),
|
||||
).Once()
|
||||
|
||||
backupService.MakeBackup(database.ID, true)
|
||||
backuperNode.MakeBackup(backup.ID, true)
|
||||
|
||||
// Verify all expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
@@ -87,6 +78,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
|
||||
t.Run("BackupSuccess_SuccessNotificationSent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.notificationSender = mockNotificationSender
|
||||
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
|
||||
|
||||
// Create a backup record directly that will be looked up by MakeBackup
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Set up expectations
|
||||
mockNotificationSender.On("SendNotification",
|
||||
@@ -99,25 +103,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
}),
|
||||
).Once()
|
||||
|
||||
backupService := &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateSuccessBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
nil,
|
||||
}
|
||||
|
||||
backupService.MakeBackup(database.ID, true)
|
||||
backuperNode.MakeBackup(backup.ID, true)
|
||||
|
||||
// Verify all expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
@@ -125,23 +111,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
|
||||
t.Run("BackupSuccess_VerifyNotificationContent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backupService := &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateSuccessBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
nil,
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.notificationSender = mockNotificationSender
|
||||
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
|
||||
|
||||
// Create a backup record directly that will be looked up by MakeBackup
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// capture arguments
|
||||
var capturedNotifier *notifiers.Notifier
|
||||
@@ -158,7 +140,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
capturedMessage = args.Get(2).(string)
|
||||
}).Once()
|
||||
|
||||
backupService.MakeBackup(database.ID, true)
|
||||
backuperNode.MakeBackup(backup.ID, true)
|
||||
|
||||
// Verify expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
@@ -171,36 +153,3 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
assert.Equal(t, notifier.ID, capturedNotifier.ID)
|
||||
})
|
||||
}
|
||||
|
||||
type CreateFailedBackupUsecase struct {
|
||||
}
|
||||
|
||||
func (uc *CreateFailedBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(10)
|
||||
return nil, errors.New("backup failed")
|
||||
}
|
||||
|
||||
type CreateSuccessBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateSuccessBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(10)
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
519
backend/internal/features/backups/backups/backuping/cleaner.go
Normal file
519
backend/internal/features/backups/backups/backuping/cleaner.go
Normal file
@@ -0,0 +1,519 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/storages"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
|
||||
const (
|
||||
cleanerTickerInterval = 1 * time.Minute
|
||||
recentBackupGracePeriod = 60 * time.Minute
|
||||
)
|
||||
|
||||
type BackupCleaner struct {
|
||||
backupRepository *backups_core.BackupRepository
|
||||
storageService *storages.StorageService
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
billingService BillingService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
logger *slog.Logger
|
||||
backupRemoveListeners []backups_core.BackupRemoveListener
|
||||
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) Run(ctx context.Context) {
|
||||
if c.hasRun.Swap(true) {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", c))
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
retentionLog := c.logger.With("task_name", "clean_by_retention_policy")
|
||||
exceededLog := c.logger.With("task_name", "clean_exceeded_storage_backups")
|
||||
staleLog := c.logger.With("task_name", "clean_stale_basebackups")
|
||||
|
||||
ticker := time.NewTicker(cleanerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := c.cleanByRetentionPolicy(retentionLog); err != nil {
|
||||
retentionLog.Error("failed to clean backups by retention policy", "error", err)
|
||||
}
|
||||
|
||||
if err := c.cleanExceededStorageBackups(exceededLog); err != nil {
|
||||
exceededLog.Error("failed to clean exceeded backups", "error", err)
|
||||
}
|
||||
|
||||
if err := c.cleanStaleUploadedBasebackups(staleLog); err != nil {
|
||||
staleLog.Error("failed to clean stale uploaded basebackups", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) DeleteBackup(backup *backups_core.Backup) error {
|
||||
for _, listener := range c.backupRemoveListeners {
|
||||
if err := listener.OnBeforeBackupRemove(backup); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
storage, err := c.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := storage.DeleteFile(c.fieldEncryptor, backup.FileName); err != nil {
|
||||
// we do not return error here, because sometimes clean up performed
|
||||
// before unavailable storage removal or change - therefore we should
|
||||
// proceed even in case of error. It's possible that some S3 or
|
||||
// storage is not available yet, it should not block us
|
||||
c.logger.Error("Failed to delete backup file", "error", err)
|
||||
}
|
||||
|
||||
metadataFileName := backup.FileName + ".metadata"
|
||||
if err := storage.DeleteFile(c.fieldEncryptor, metadataFileName); err != nil {
|
||||
c.logger.Error("Failed to delete backup metadata file", "error", err)
|
||||
}
|
||||
|
||||
return c.backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
|
||||
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanStaleUploadedBasebackups(logger *slog.Logger) error {
|
||||
staleBackups, err := c.backupRepository.FindStaleUploadedBasebackups(
|
||||
time.Now().UTC().Add(-10 * time.Minute),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find stale uploaded basebackups: %w", err)
|
||||
}
|
||||
|
||||
for _, backup := range staleBackups {
|
||||
backupLog := logger.With("database_id", backup.DatabaseID, "backup_id", backup.ID)
|
||||
|
||||
staleStorage, storageErr := c.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr != nil {
|
||||
backupLog.Error(
|
||||
"failed to get storage for stale basebackup cleanup",
|
||||
"storage_id", backup.StorageID,
|
||||
"error", storageErr,
|
||||
)
|
||||
} else {
|
||||
if err := staleStorage.DeleteFile(c.fieldEncryptor, backup.FileName); err != nil {
|
||||
backupLog.Error(
|
||||
fmt.Sprintf("failed to delete stale basebackup file: %s", backup.FileName),
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
metadataFileName := backup.FileName + ".metadata"
|
||||
if err := staleStorage.DeleteFile(c.fieldEncryptor, metadataFileName); err != nil {
|
||||
backupLog.Error(
|
||||
fmt.Sprintf("failed to delete stale basebackup metadata file: %s", metadataFileName),
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
failMsg := "basebackup finalization timed out after 10 minutes"
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.FailMessage = &failMsg
|
||||
|
||||
if err := c.backupRepository.Save(backup); err != nil {
|
||||
backupLog.Error("failed to mark stale uploaded basebackup as failed", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
backupLog.Info("marked stale uploaded basebackup as failed and cleaned storage")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByRetentionPolicy(logger *slog.Logger) error {
|
||||
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
dbLog := logger.With("database_id", backupConfig.DatabaseID, "policy", backupConfig.RetentionPolicyType)
|
||||
|
||||
var cleanErr error
|
||||
|
||||
switch backupConfig.RetentionPolicyType {
|
||||
case backups_config.RetentionPolicyTypeCount:
|
||||
cleanErr = c.cleanByCount(dbLog, backupConfig)
|
||||
case backups_config.RetentionPolicyTypeGFS:
|
||||
cleanErr = c.cleanByGFS(dbLog, backupConfig)
|
||||
default:
|
||||
cleanErr = c.cleanByTimePeriod(dbLog, backupConfig)
|
||||
}
|
||||
|
||||
if cleanErr != nil {
|
||||
dbLog.Error("failed to clean backups by retention policy", "error", cleanErr)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanExceededStorageBackups(logger *slog.Logger) error {
|
||||
if !config.GetEnv().IsCloud {
|
||||
return nil
|
||||
}
|
||||
|
||||
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
dbLog := logger.With("database_id", backupConfig.DatabaseID)
|
||||
|
||||
subscription, subErr := c.billingService.GetSubscription(dbLog, backupConfig.DatabaseID)
|
||||
if subErr != nil {
|
||||
dbLog.Error("failed to get subscription for exceeded backups check", "error", subErr)
|
||||
continue
|
||||
}
|
||||
|
||||
storageLimitMB := int64(subscription.GetBackupsStorageGB()) * 1024
|
||||
|
||||
if err := c.cleanExceededBackupsForDatabase(dbLog, backupConfig.DatabaseID, storageLimitMB); err != nil {
|
||||
dbLog.Error("failed to clean exceeded backups for database", "error", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByTimePeriod(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
|
||||
if backupConfig.RetentionTimePeriod == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if backupConfig.RetentionTimePeriod == period.PeriodForever {
|
||||
return nil
|
||||
}
|
||||
|
||||
storeDuration := backupConfig.RetentionTimePeriod.ToDuration()
|
||||
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
|
||||
|
||||
oldBackups, err := c.backupRepository.FindBackupsBeforeDate(
|
||||
backupConfig.DatabaseID,
|
||||
dateBeforeBackupsShouldBeDeleted,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to find old backups for database %s: %w",
|
||||
backupConfig.DatabaseID,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
for _, backup := range oldBackups {
|
||||
if isRecentBackup(backup) {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
logger.Error("failed to delete old backup", "backup_id", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info("deleted old backup", "backup_id", backup.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByCount(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
|
||||
if backupConfig.RetentionCount <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
completedBackups, err := c.backupRepository.FindByDatabaseIdAndStatus(
|
||||
backupConfig.DatabaseID,
|
||||
backups_core.BackupStatusCompleted,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to find completed backups for database %s: %w",
|
||||
backupConfig.DatabaseID,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
// completedBackups are ordered newest first; delete everything beyond position RetentionCount
|
||||
if len(completedBackups) <= backupConfig.RetentionCount {
|
||||
return nil
|
||||
}
|
||||
|
||||
toDelete := completedBackups[backupConfig.RetentionCount:]
|
||||
for _, backup := range toDelete {
|
||||
if isRecentBackup(backup) {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
logger.Error("failed to delete backup by count policy", "backup_id", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info(
|
||||
fmt.Sprintf("deleted backup by count policy: retention count is %d", backupConfig.RetentionCount),
|
||||
"backup_id", backup.ID,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByGFS(logger *slog.Logger, backupConfig *backups_config.BackupConfig) error {
|
||||
if backupConfig.RetentionGfsHours <= 0 && backupConfig.RetentionGfsDays <= 0 &&
|
||||
backupConfig.RetentionGfsWeeks <= 0 && backupConfig.RetentionGfsMonths <= 0 &&
|
||||
backupConfig.RetentionGfsYears <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
completedBackups, err := c.backupRepository.FindByDatabaseIdAndStatus(
|
||||
backupConfig.DatabaseID,
|
||||
backups_core.BackupStatusCompleted,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to find completed backups for database %s: %w",
|
||||
backupConfig.DatabaseID,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
keepSet := buildGFSKeepSet(
|
||||
completedBackups,
|
||||
backupConfig.RetentionGfsHours,
|
||||
backupConfig.RetentionGfsDays,
|
||||
backupConfig.RetentionGfsWeeks,
|
||||
backupConfig.RetentionGfsMonths,
|
||||
backupConfig.RetentionGfsYears,
|
||||
)
|
||||
|
||||
for _, backup := range completedBackups {
|
||||
if keepSet[backup.ID] {
|
||||
continue
|
||||
}
|
||||
|
||||
if isRecentBackup(backup) {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
logger.Error("failed to delete backup by GFS policy", "backup_id", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info("deleted backup by GFS policy", "backup_id", backup.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanExceededBackupsForDatabase(
|
||||
logger *slog.Logger,
|
||||
databaseID uuid.UUID,
|
||||
limitPerDbMB int64,
|
||||
) error {
|
||||
for {
|
||||
backupsTotalSizeMB, err := c.backupRepository.GetTotalSizeByDatabase(databaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if backupsTotalSizeMB <= float64(limitPerDbMB) {
|
||||
break
|
||||
}
|
||||
|
||||
oldestBackups, err := c.backupRepository.FindOldestByDatabaseExcludingInProgress(
|
||||
databaseID,
|
||||
1,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(oldestBackups) == 0 {
|
||||
logger.Warn(fmt.Sprintf(
|
||||
"no backups to delete but still over limit: total size is %.1f MB, limit is %d MB",
|
||||
backupsTotalSizeMB, limitPerDbMB,
|
||||
))
|
||||
break
|
||||
}
|
||||
|
||||
backup := oldestBackups[0]
|
||||
if isRecentBackup(backup) {
|
||||
break
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
logger.Error("failed to delete exceeded backup", "backup_id", backup.ID, "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info(
|
||||
fmt.Sprintf("deleted exceeded backup: backup size is %.1f MB, total size is %.1f MB, limit is %d MB",
|
||||
backup.BackupSizeMb, backupsTotalSizeMB, limitPerDbMB),
|
||||
"backup_id", backup.ID,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isRecentBackup(backup *backups_core.Backup) bool {
|
||||
return time.Since(backup.CreatedAt) < recentBackupGracePeriod
|
||||
}
|
||||
|
||||
// buildGFSKeepSet determines which backups to retain under the GFS rotation scheme.
|
||||
// Backups must be sorted newest-first. A backup can fill multiple slots simultaneously
|
||||
// (e.g. the newest backup of a year also fills the monthly, weekly, daily, and hourly slot).
|
||||
func buildGFSKeepSet(
|
||||
backups []*backups_core.Backup,
|
||||
hours, days, weeks, months, years int,
|
||||
) map[uuid.UUID]bool {
|
||||
keep := make(map[uuid.UUID]bool)
|
||||
|
||||
if len(backups) == 0 {
|
||||
return keep
|
||||
}
|
||||
|
||||
hoursSeen := make(map[string]bool)
|
||||
daysSeen := make(map[string]bool)
|
||||
weeksSeen := make(map[string]bool)
|
||||
monthsSeen := make(map[string]bool)
|
||||
yearsSeen := make(map[string]bool)
|
||||
|
||||
hoursKept, daysKept, weeksKept, monthsKept, yearsKept := 0, 0, 0, 0, 0
|
||||
|
||||
// Compute per-level time-window cutoffs so higher-frequency slots
|
||||
// cannot absorb backups that belong to lower-frequency levels.
|
||||
ref := backups[0].CreatedAt
|
||||
|
||||
rawHourlyCutoff := ref.Add(-time.Duration(hours) * time.Hour)
|
||||
rawDailyCutoff := ref.Add(-time.Duration(days) * 24 * time.Hour)
|
||||
rawWeeklyCutoff := ref.Add(-time.Duration(weeks) * 7 * 24 * time.Hour)
|
||||
rawMonthlyCutoff := ref.AddDate(0, -months, 0)
|
||||
rawYearlyCutoff := ref.AddDate(-years, 0, 0)
|
||||
|
||||
// Hierarchical capping: each level's window cannot extend further back
|
||||
// than the nearest active lower-frequency level's window.
|
||||
yearlyCutoff := rawYearlyCutoff
|
||||
|
||||
monthlyCutoff := rawMonthlyCutoff
|
||||
if years > 0 {
|
||||
monthlyCutoff = laterOf(monthlyCutoff, yearlyCutoff)
|
||||
}
|
||||
|
||||
weeklyCutoff := rawWeeklyCutoff
|
||||
if months > 0 {
|
||||
weeklyCutoff = laterOf(weeklyCutoff, monthlyCutoff)
|
||||
} else if years > 0 {
|
||||
weeklyCutoff = laterOf(weeklyCutoff, yearlyCutoff)
|
||||
}
|
||||
|
||||
dailyCutoff := rawDailyCutoff
|
||||
switch {
|
||||
case weeks > 0:
|
||||
dailyCutoff = laterOf(dailyCutoff, weeklyCutoff)
|
||||
case months > 0:
|
||||
dailyCutoff = laterOf(dailyCutoff, monthlyCutoff)
|
||||
case years > 0:
|
||||
dailyCutoff = laterOf(dailyCutoff, yearlyCutoff)
|
||||
}
|
||||
|
||||
hourlyCutoff := rawHourlyCutoff
|
||||
switch {
|
||||
case days > 0:
|
||||
hourlyCutoff = laterOf(hourlyCutoff, dailyCutoff)
|
||||
case weeks > 0:
|
||||
hourlyCutoff = laterOf(hourlyCutoff, weeklyCutoff)
|
||||
case months > 0:
|
||||
hourlyCutoff = laterOf(hourlyCutoff, monthlyCutoff)
|
||||
case years > 0:
|
||||
hourlyCutoff = laterOf(hourlyCutoff, yearlyCutoff)
|
||||
}
|
||||
|
||||
for _, backup := range backups {
|
||||
t := backup.CreatedAt
|
||||
|
||||
hourKey := t.Format("2006-01-02-15")
|
||||
dayKey := t.Format("2006-01-02")
|
||||
weekYear, week := t.ISOWeek()
|
||||
weekKey := fmt.Sprintf("%d-%02d", weekYear, week)
|
||||
monthKey := t.Format("2006-01")
|
||||
yearKey := t.Format("2006")
|
||||
|
||||
if hours > 0 && hoursKept < hours && !hoursSeen[hourKey] && t.After(hourlyCutoff) {
|
||||
keep[backup.ID] = true
|
||||
hoursSeen[hourKey] = true
|
||||
hoursKept++
|
||||
}
|
||||
|
||||
if days > 0 && daysKept < days && !daysSeen[dayKey] && t.After(dailyCutoff) {
|
||||
keep[backup.ID] = true
|
||||
daysSeen[dayKey] = true
|
||||
daysKept++
|
||||
}
|
||||
|
||||
if weeks > 0 && weeksKept < weeks && !weeksSeen[weekKey] && t.After(weeklyCutoff) {
|
||||
keep[backup.ID] = true
|
||||
weeksSeen[weekKey] = true
|
||||
weeksKept++
|
||||
}
|
||||
|
||||
if months > 0 && monthsKept < months && !monthsSeen[monthKey] && t.After(monthlyCutoff) {
|
||||
keep[backup.ID] = true
|
||||
monthsSeen[monthKey] = true
|
||||
monthsKept++
|
||||
}
|
||||
|
||||
if years > 0 && yearsKept < years && !yearsSeen[yearKey] && t.After(yearlyCutoff) {
|
||||
keep[backup.ID] = true
|
||||
yearsSeen[yearKey] = true
|
||||
yearsKept++
|
||||
}
|
||||
}
|
||||
|
||||
return keep
|
||||
}
|
||||
|
||||
func laterOf(a, b time.Time) time.Time {
|
||||
if a.After(b) {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
1328
backend/internal/features/backups/backups/backuping/cleaner_test.go
Normal file
1328
backend/internal/features/backups/backups/backuping/cleaner_test.go
Normal file
File diff suppressed because it is too large
Load Diff
96
backend/internal/features/backups/backups/backuping/di.go
Normal file
96
backend/internal/features/backups/backups/backuping/di.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/billing"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var backupRepository = &backups_core.BackupRepository{}
|
||||
|
||||
var taskCancelManager = tasks_cancellation.GetTaskCancelManager()
|
||||
|
||||
var backupCleaner = &BackupCleaner{
|
||||
backupRepository,
|
||||
storages.GetStorageService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
billing.GetBillingService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
logger.GetLogger(),
|
||||
[]backups_core.BackupRemoveListener{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
var backupNodesRegistry = &BackupNodesRegistry{
|
||||
cache_utils.GetValkeyClient(),
|
||||
logger.GetLogger(),
|
||||
cache_utils.DefaultCacheTimeout,
|
||||
cache_utils.NewPubSubManager(),
|
||||
cache_utils.NewPubSubManager(),
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
func getNodeID() uuid.UUID {
|
||||
return uuid.New()
|
||||
}
|
||||
|
||||
var backuperNode = &BackuperNode{
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
logger.GetLogger(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
getNodeID(),
|
||||
time.Time{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
var backupsScheduler = &BackupsScheduler{
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
databases.GetDatabaseService(),
|
||||
billing.GetBillingService(),
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode,
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
func GetBackupsScheduler() *BackupsScheduler {
|
||||
return backupsScheduler
|
||||
}
|
||||
|
||||
func GetBackuperNode() *BackuperNode {
|
||||
return backuperNode
|
||||
}
|
||||
|
||||
func GetBackupNodesRegistry() *BackupNodesRegistry {
|
||||
return backupNodesRegistry
|
||||
}
|
||||
|
||||
func GetBackupCleaner() *BackupCleaner {
|
||||
return backupCleaner
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user