mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
Compare commits
269 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
415dda8752 | ||
|
|
3faf85796a | ||
|
|
edd2759f5a | ||
|
|
c283856f38 | ||
|
|
6059e1a33b | ||
|
|
2deda2e7ea | ||
|
|
acf1143752 | ||
|
|
889063a8b4 | ||
|
|
a1e20e7b10 | ||
|
|
7e76945550 | ||
|
|
d98acfc4af | ||
|
|
0ffc7c8c96 | ||
|
|
1b011bdcd4 | ||
|
|
7e209ff537 | ||
|
|
f712e3a437 | ||
|
|
bcd7d8e1aa | ||
|
|
880a7488e9 | ||
|
|
ca4d483f2c | ||
|
|
1b511410a6 | ||
|
|
c8edff8046 | ||
|
|
f60e3d956b | ||
|
|
f2cb9022f2 | ||
|
|
4b3f36eea2 | ||
|
|
460063e7a5 | ||
|
|
a0f02b253e | ||
|
|
812f11bc2f | ||
|
|
e796e3ddf0 | ||
|
|
c96d3db337 | ||
|
|
ed6c3a2034 | ||
|
|
05115047c3 | ||
|
|
446b96c6c0 | ||
|
|
36a0448da1 | ||
|
|
8e392cfeab | ||
|
|
6683db1e52 | ||
|
|
703b883936 | ||
|
|
e818bcff82 | ||
|
|
b2f98f1332 | ||
|
|
230cc27ea6 | ||
|
|
cd197ff94b | ||
|
|
91f35a3e17 | ||
|
|
30c2e2d156 | ||
|
|
ef7c5b45e6 | ||
|
|
920c98e229 | ||
|
|
2a19a96aae | ||
|
|
75aa2108d9 | ||
|
|
0a0040839e | ||
|
|
ff4f795ece | ||
|
|
dc05502580 | ||
|
|
1ca38f5583 | ||
|
|
40b3ff61c7 | ||
|
|
e1b245a573 | ||
|
|
fdf29b71f2 | ||
|
|
49da981c21 | ||
|
|
9d611d3559 | ||
|
|
22cab53dab | ||
|
|
d761c4156c | ||
|
|
cbb8b82711 | ||
|
|
8e3d1e5bff | ||
|
|
349e7f0ee8 | ||
|
|
3a274e135b | ||
|
|
61e937bc2a | ||
|
|
f67919fe1a | ||
|
|
91ee5966d8 | ||
|
|
d77d7d69a3 | ||
|
|
fc88b730d5 | ||
|
|
1f1d80245f | ||
|
|
16a29cf458 | ||
|
|
43e04500ac | ||
|
|
cee3022f85 | ||
|
|
f46d92c480 | ||
|
|
10677238d7 | ||
|
|
2553203fcf | ||
|
|
7b05bd8000 | ||
|
|
8d45728f73 | ||
|
|
c70ad82c95 | ||
|
|
e4bc34d319 | ||
|
|
257ae85da7 | ||
|
|
b42c820bb2 | ||
|
|
da5c13fb11 | ||
|
|
35180360e5 | ||
|
|
e4f6cd7a5d | ||
|
|
d7b8e6d56a | ||
|
|
6016f23fb2 | ||
|
|
e7c4ee8f6f | ||
|
|
a75702a01b | ||
|
|
81a21eb907 | ||
|
|
33d6bf0147 | ||
|
|
6eb53bb07b | ||
|
|
6ac04270b9 | ||
|
|
b0510d7c21 | ||
|
|
dc5f271882 | ||
|
|
8f718771c9 | ||
|
|
d8eea05dca | ||
|
|
b2a94274d7 | ||
|
|
77c2712ebb | ||
|
|
a9dc29f82c | ||
|
|
c934a45dca | ||
|
|
d4acdf2826 | ||
|
|
49753c4fc0 | ||
|
|
c6aed6b36d | ||
|
|
3060b4266a | ||
|
|
ebeb597f17 | ||
|
|
4783784325 | ||
|
|
bd41433bdb | ||
|
|
a9073787d2 | ||
|
|
0890bf8f09 | ||
|
|
f8c11e8802 | ||
|
|
e798d82fc1 | ||
|
|
81a01585ee | ||
|
|
a8465c1a10 | ||
|
|
a9e5db70f6 | ||
|
|
7a47be6ca6 | ||
|
|
16be3db0c6 | ||
|
|
744e51d1e1 | ||
|
|
b3af75d430 | ||
|
|
6f7320abeb | ||
|
|
a1655d35a6 | ||
|
|
9b6e801184 | ||
|
|
105777ab6f | ||
|
|
3a1a88d5cf | ||
|
|
699ca16814 | ||
|
|
26f3cf233a | ||
|
|
3d8372e9f6 | ||
|
|
b46f11804d | ||
|
|
4676361688 | ||
|
|
de3679cadf | ||
|
|
8f03a30af2 | ||
|
|
356529c58a | ||
|
|
e7eed056f7 | ||
|
|
6084cdc954 | ||
|
|
c50bcc57b1 | ||
|
|
ea76300ed7 | ||
|
|
9b413e4076 | ||
|
|
f91cb260f2 | ||
|
|
8f37a8082f | ||
|
|
5cf7614772 | ||
|
|
ae27f74c2e | ||
|
|
9457516bb9 | ||
|
|
a36fc5bf8c | ||
|
|
03ada5806d | ||
|
|
a6675390e5 | ||
|
|
af2f978876 | ||
|
|
04e7eba5c5 | ||
|
|
520165541d | ||
|
|
5b556bc161 | ||
|
|
0952a15ec5 | ||
|
|
1afb3aa3ff | ||
|
|
19b92e5f74 | ||
|
|
d4763f26b2 | ||
|
|
0e389ba16b | ||
|
|
594a3294c6 | ||
|
|
4e4a323cf1 | ||
|
|
7d9ecf697b | ||
|
|
755c420157 | ||
|
|
ff73627287 | ||
|
|
9c9ab00ace | ||
|
|
7366e21a1a | ||
|
|
a327d1aa57 | ||
|
|
f152b16ea3 | ||
|
|
85dbe80d3d | ||
|
|
edf4028fd1 | ||
|
|
8d85c45a90 | ||
|
|
d9c176d19a | ||
|
|
7a6f72a456 | ||
|
|
9a1471b88b | ||
|
|
386ea1d708 | ||
|
|
a4b23936ee | ||
|
|
b36aa9d48b | ||
|
|
13cb8e5bd2 | ||
|
|
2db4b6e075 | ||
|
|
f2b0b2bf1f | ||
|
|
7142ce295e | ||
|
|
04621b9b2d | ||
|
|
bd329a68cf | ||
|
|
f957abc9db | ||
|
|
c0fd6be1a9 | ||
|
|
c39bd34d5e | ||
|
|
27bec15a29 | ||
|
|
d98baa0656 | ||
|
|
4344f5ea5e | ||
|
|
7c6afa5b88 | ||
|
|
dbac799e1b | ||
|
|
7ee3817089 | ||
|
|
bae6f7f007 | ||
|
|
55dc087ddd | ||
|
|
c94d0db637 | ||
|
|
a1adef2261 | ||
|
|
4602dc3f88 | ||
|
|
cbbfc5ea8f | ||
|
|
dd1072e230 | ||
|
|
a495e5317a | ||
|
|
7eed647038 | ||
|
|
6973241e25 | ||
|
|
ab181f5b81 | ||
|
|
b60a0cc170 | ||
|
|
f319a497b3 | ||
|
|
bc870b3f8e | ||
|
|
15383c59eb | ||
|
|
d14c223a65 | ||
|
|
2c0a294027 | ||
|
|
5d851d73bd | ||
|
|
699913c251 | ||
|
|
a2e3f30a6d | ||
|
|
80f1174ecd | ||
|
|
a47f8d5e2c | ||
|
|
54b9e67656 | ||
|
|
3782846872 | ||
|
|
245a81897f | ||
|
|
5cbc0773b6 | ||
|
|
997fc01442 | ||
|
|
6d0ae32d0c | ||
|
|
011985d723 | ||
|
|
d677ee61de | ||
|
|
c6b8f6e87a | ||
|
|
2bb5f93d00 | ||
|
|
b91c150300 | ||
|
|
12b119ce40 | ||
|
|
7c6f0ab4ba | ||
|
|
6d2db4b298 | ||
|
|
6397423298 | ||
|
|
3470aae8e3 | ||
|
|
184fbcdb2c | ||
|
|
2d897dd722 | ||
|
|
cba40afd00 | ||
|
|
7aea012aeb | ||
|
|
6d5534deaa | ||
|
|
c04bd54683 | ||
|
|
1c3f16b372 | ||
|
|
ed08da56a6 | ||
|
|
c53e84b48d | ||
|
|
dbfeb9e27f | ||
|
|
02e86ffb3b | ||
|
|
207382116c | ||
|
|
a91ee50e31 | ||
|
|
7e5562b115 | ||
|
|
3ef51c4d68 | ||
|
|
e47e513460 | ||
|
|
226a6c06e6 | ||
|
|
615fd9d574 | ||
|
|
e9fcf20cdf | ||
|
|
7649f4acfd | ||
|
|
7e4c3bcc19 | ||
|
|
f2aecc0427 | ||
|
|
3ce7da319f | ||
|
|
096098f660 | ||
|
|
c3ba4a7c5a | ||
|
|
52c0f53608 | ||
|
|
a5095acad4 | ||
|
|
a6d32b5c09 | ||
|
|
722560e824 | ||
|
|
496ac6120c | ||
|
|
756c6c87af | ||
|
|
a23d05b735 | ||
|
|
33a8d302eb | ||
|
|
25ed1ffd2a | ||
|
|
67582325bb | ||
|
|
5a89558cf6 | ||
|
|
0ec02430b7 | ||
|
|
49115684a7 | ||
|
|
58ae86ff7a | ||
|
|
82939bb079 | ||
|
|
1697bfbae8 | ||
|
|
205cb1ec02 | ||
|
|
b9668875ef | ||
|
|
ca3f0281a3 | ||
|
|
1b8d783d4e | ||
|
|
75b0477874 | ||
|
|
19533514c2 | ||
|
|
b3c3ef136f |
44
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
44
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
---
|
||||
name: Bug Report
|
||||
about: Report a bug or unexpected behavior in Databasus
|
||||
labels: bug
|
||||
---
|
||||
|
||||
## Databasus version (screenshot)
|
||||
|
||||
It is displayed in the bottom left corner of the Databasus UI. Please attach screenshot, not just version text
|
||||
|
||||
<!-- e.g. 1.4.2 -->
|
||||
|
||||
## Operating system and architecture
|
||||
|
||||
<!-- e.g. Ubuntu 22.04 x64, macOS 14 ARM, Windows 11 x64 -->
|
||||
|
||||
## Database type and version (optional, for DB-related bugs)
|
||||
|
||||
<!-- e.g. PostgreSQL 16 in Docker, MySQL 8.0 installed on server, MariaDB 11.4 in AWS Cloud -->
|
||||
|
||||
## Describe the bug (please write manually, do not ask AI to summarize)
|
||||
|
||||
**What happened:**
|
||||
|
||||
**What I expected:**
|
||||
|
||||
## Steps to reproduce
|
||||
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
## Have you asked AI how to solve the issue?
|
||||
|
||||
<!-- Using AI to diagnose issues before filing a bug report helps narrow down root causes. -->
|
||||
|
||||
- [ ] Claude Sonnet 4.6 or newer
|
||||
- [ ] ChatGPT 5.2 or newer
|
||||
- [ ] No
|
||||
|
||||
|
||||
## Additional context / logs
|
||||
|
||||
<!-- Screenshots, error messages, relevant log output, etc. -->
|
||||
418
.github/workflows/ci-release.yml
vendored
418
.github/workflows/ci-release.yml
vendored
@@ -9,29 +9,30 @@ on:
|
||||
|
||||
jobs:
|
||||
lint-backend:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: golang:1.26.1
|
||||
volumes:
|
||||
- /runner-cache/go-pkg:/go/pkg/mod
|
||||
- /runner-cache/go-build:/root/.cache/go-build
|
||||
- /runner-cache/golangci-lint:/root/.cache/golangci-lint
|
||||
- /runner-cache/apt-archives:/var/cache/apt/archives
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.24.4"
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
- name: Download Go modules
|
||||
run: |
|
||||
cd backend
|
||||
go mod download
|
||||
|
||||
- name: Install golangci-lint
|
||||
run: |
|
||||
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.7.2
|
||||
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.11.3
|
||||
echo "$(go env GOPATH)/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Install swag for swagger generation
|
||||
@@ -63,8 +64,6 @@ jobs:
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
cache: "npm"
|
||||
cache-dependency-path: frontend/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
@@ -82,6 +81,44 @@ jobs:
|
||||
cd frontend
|
||||
npm run lint
|
||||
|
||||
- name: Build frontend
|
||||
run: |
|
||||
cd frontend
|
||||
npm run build
|
||||
|
||||
lint-agent:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.26.1"
|
||||
cache-dependency-path: agent/go.sum
|
||||
|
||||
- name: Download Go modules
|
||||
run: |
|
||||
cd agent
|
||||
go mod download
|
||||
|
||||
- name: Install golangci-lint
|
||||
run: |
|
||||
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.11.3
|
||||
echo "$(go env GOPATH)/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Run golangci-lint
|
||||
run: |
|
||||
cd agent
|
||||
golangci-lint run
|
||||
|
||||
- name: Verify go mod tidy
|
||||
run: |
|
||||
cd agent
|
||||
go mod tidy
|
||||
git diff --exit-code go.mod go.sum || (echo "go mod tidy made changes, please run 'go mod tidy' and commit the changes" && exit 1)
|
||||
|
||||
test-frontend:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint-frontend]
|
||||
@@ -93,8 +130,6 @@ jobs:
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
cache: "npm"
|
||||
cache-dependency-path: frontend/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
@@ -106,45 +141,77 @@ jobs:
|
||||
cd frontend
|
||||
npm run test
|
||||
|
||||
test-backend:
|
||||
test-agent:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint-backend]
|
||||
needs: [lint-agent]
|
||||
steps:
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
echo "Disk space before cleanup:"
|
||||
df -h
|
||||
# Remove unnecessary pre-installed software
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo rm -rf /usr/local/share/boost
|
||||
sudo rm -rf /usr/share/swift
|
||||
# Clean apt cache
|
||||
sudo apt-get clean
|
||||
# Clean docker images (if any pre-installed)
|
||||
docker system prune -af --volumes || true
|
||||
echo "Disk space after cleanup:"
|
||||
df -h
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.24.4"
|
||||
go-version: "1.26.1"
|
||||
cache-dependency-path: agent/go.sum
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
- name: Download Go modules
|
||||
run: |
|
||||
cd agent
|
||||
go mod download
|
||||
|
||||
- name: Run Go tests
|
||||
run: |
|
||||
cd agent
|
||||
go test -count=1 -failfast ./internal/...
|
||||
|
||||
e2e-agent:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint-agent]
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Run e2e tests
|
||||
run: |
|
||||
cd agent
|
||||
make e2e
|
||||
|
||||
- name: Cleanup
|
||||
if: always()
|
||||
run: |
|
||||
cd agent/e2e
|
||||
docker compose down -v --rmi local || true
|
||||
rm -rf artifacts || true
|
||||
|
||||
# Self-hosted: performant high-frequency CPU is used to start many containers and run tests fast. Tests
|
||||
# step is bottle-neck, because we need a lot of containers and cannot parallelize tests due to shared resources
|
||||
test-backend:
|
||||
runs-on: self-hosted
|
||||
needs: [lint-backend]
|
||||
container:
|
||||
image: golang:1.26.1
|
||||
options: --privileged -v /var/run/docker.sock:/var/run/docker.sock --add-host=host.docker.internal:host-gateway
|
||||
volumes:
|
||||
- /runner-cache/go-pkg:/go/pkg/mod
|
||||
- /runner-cache/go-build:/root/.cache/go-build
|
||||
- /runner-cache/apt-archives:/var/cache/apt/archives
|
||||
steps:
|
||||
- name: Install Docker CLI
|
||||
run: |
|
||||
apt-get update -qq
|
||||
apt-get install -y -qq docker.io docker-compose netcat-openbsd wget
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Download Go modules
|
||||
run: |
|
||||
cd backend
|
||||
go mod download
|
||||
|
||||
- name: Create .env file for testing
|
||||
run: |
|
||||
@@ -156,14 +223,16 @@ jobs:
|
||||
DEV_DB_PASSWORD=Q1234567
|
||||
#app
|
||||
ENV_MODE=development
|
||||
# db
|
||||
DATABASE_DSN=host=localhost user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
|
||||
DATABASE_URL=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
|
||||
# db - using 172.17.0.1 to access host from container
|
||||
DATABASE_DSN=host=172.17.0.1 user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
|
||||
DATABASE_URL=postgres://postgres:Q1234567@172.17.0.1:5437/databasus?sslmode=disable
|
||||
# migrations
|
||||
GOOSE_DRIVER=postgres
|
||||
GOOSE_DBSTRING=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
|
||||
GOOSE_DBSTRING=postgres://postgres:Q1234567@172.17.0.1:5437/databasus?sslmode=disable
|
||||
GOOSE_MIGRATION_DIR=./migrations
|
||||
# testing
|
||||
# testing
|
||||
TEST_LOCALHOST=172.17.0.1
|
||||
IS_SKIP_EXTERNAL_RESOURCES_TESTS=true
|
||||
# to get Google Drive env variables: add storage in UI and copy data from added storage here
|
||||
TEST_GOOGLE_DRIVE_CLIENT_ID=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_ID }}
|
||||
TEST_GOOGLE_DRIVE_CLIENT_SECRET=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_SECRET }}
|
||||
@@ -221,6 +290,14 @@ jobs:
|
||||
TEST_MONGODB_60_PORT=27060
|
||||
TEST_MONGODB_70_PORT=27070
|
||||
TEST_MONGODB_82_PORT=27082
|
||||
# Valkey (cache) - using 172.17.0.1
|
||||
VALKEY_HOST=172.17.0.1
|
||||
VALKEY_PORT=6379
|
||||
VALKEY_USERNAME=
|
||||
VALKEY_PASSWORD=
|
||||
VALKEY_IS_SSL=false
|
||||
# Host for test databases (container -> host)
|
||||
TEST_DB_HOST=172.17.0.1
|
||||
EOF
|
||||
|
||||
- name: Start test containers
|
||||
@@ -233,25 +310,30 @@ jobs:
|
||||
# Wait for main dev database
|
||||
timeout 60 bash -c 'until docker exec dev-db pg_isready -h localhost -p 5437 -U postgres; do sleep 2; done'
|
||||
|
||||
# Wait for test databases
|
||||
timeout 60 bash -c 'until nc -z localhost 5000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5001; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5002; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5003; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5004; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5005; do sleep 2; done'
|
||||
# Wait for Valkey (cache)
|
||||
echo "Waiting for Valkey..."
|
||||
timeout 60 bash -c 'until docker exec dev-valkey valkey-cli ping 2>/dev/null | grep -q PONG; do sleep 2; done'
|
||||
echo "Valkey is ready!"
|
||||
|
||||
# Wait for test databases (using 172.17.0.1 from container)
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5001; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5002; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5003; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5004; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5005; do sleep 2; done'
|
||||
|
||||
# Wait for MinIO
|
||||
timeout 60 bash -c 'until nc -z localhost 9000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 9000; do sleep 2; done'
|
||||
|
||||
# Wait for Azurite
|
||||
timeout 60 bash -c 'until nc -z localhost 10000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 10000; do sleep 2; done'
|
||||
|
||||
# Wait for FTP
|
||||
timeout 60 bash -c 'until nc -z localhost 7007; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 7007; do sleep 2; done'
|
||||
|
||||
# Wait for SFTP
|
||||
timeout 60 bash -c 'until nc -z localhost 7008; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 7008; do sleep 2; done'
|
||||
|
||||
# Wait for MySQL containers
|
||||
echo "Waiting for MySQL 5.7..."
|
||||
@@ -310,67 +392,66 @@ jobs:
|
||||
mkdir -p databasus-data/backups
|
||||
mkdir -p databasus-data/temp
|
||||
|
||||
- name: Cache PostgreSQL client tools
|
||||
id: cache-postgres
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: /usr/lib/postgresql
|
||||
key: postgres-clients-12-18-v1
|
||||
|
||||
- name: Cache MySQL client tools
|
||||
id: cache-mysql
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: backend/tools/mysql
|
||||
key: mysql-clients-57-80-84-9-v1
|
||||
|
||||
- name: Cache MariaDB client tools
|
||||
id: cache-mariadb
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: backend/tools/mariadb
|
||||
key: mariadb-clients-106-121-v1
|
||||
|
||||
- name: Cache MongoDB Database Tools
|
||||
id: cache-mongodb
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: backend/tools/mongodb
|
||||
key: mongodb-database-tools-100.10.0-v1
|
||||
|
||||
- name: Install MySQL dependencies
|
||||
- name: Install database client dependencies
|
||||
run: |
|
||||
sudo apt-get update -qq
|
||||
sudo apt-get install -y -qq libncurses6
|
||||
sudo ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5
|
||||
sudo ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5
|
||||
apt-get update -qq
|
||||
apt-get install -y -qq libncurses6 libpq5
|
||||
ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5 || true
|
||||
ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5 || true
|
||||
|
||||
- name: Install PostgreSQL, MySQL, MariaDB and MongoDB client tools
|
||||
if: steps.cache-postgres.outputs.cache-hit != 'true' || steps.cache-mysql.outputs.cache-hit != 'true' || steps.cache-mariadb.outputs.cache-hit != 'true' || steps.cache-mongodb.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
chmod +x backend/tools/download_linux.sh
|
||||
cd backend/tools
|
||||
./download_linux.sh
|
||||
|
||||
- name: Setup PostgreSQL symlinks (when using cache)
|
||||
if: steps.cache-postgres.outputs.cache-hit == 'true'
|
||||
- name: Setup PostgreSQL, MySQL and MariaDB client tools from pre-built assets
|
||||
run: |
|
||||
cd backend/tools
|
||||
mkdir -p postgresql
|
||||
|
||||
# Create directory structure
|
||||
mkdir -p postgresql mysql mariadb mongodb/bin
|
||||
|
||||
# Copy PostgreSQL client tools (12-18) from pre-built assets
|
||||
for version in 12 13 14 15 16 17 18; do
|
||||
version_dir="postgresql/postgresql-$version"
|
||||
mkdir -p "$version_dir/bin"
|
||||
pg_bin_dir="/usr/lib/postgresql/$version/bin"
|
||||
if [ -d "$pg_bin_dir" ]; then
|
||||
ln -sf "$pg_bin_dir/pg_dump" "$version_dir/bin/pg_dump"
|
||||
ln -sf "$pg_bin_dir/pg_dumpall" "$version_dir/bin/pg_dumpall"
|
||||
ln -sf "$pg_bin_dir/psql" "$version_dir/bin/psql"
|
||||
ln -sf "$pg_bin_dir/pg_restore" "$version_dir/bin/pg_restore"
|
||||
ln -sf "$pg_bin_dir/createdb" "$version_dir/bin/createdb"
|
||||
ln -sf "$pg_bin_dir/dropdb" "$version_dir/bin/dropdb"
|
||||
fi
|
||||
mkdir -p postgresql/postgresql-$version
|
||||
cp -r ../../assets/tools/x64/postgresql/postgresql-$version/bin postgresql/postgresql-$version/
|
||||
done
|
||||
|
||||
# Copy MySQL client tools (5.7, 8.0, 8.4, 9) from pre-built assets
|
||||
for version in 5.7 8.0 8.4 9; do
|
||||
mkdir -p mysql/mysql-$version
|
||||
cp -r ../../assets/tools/x64/mysql/mysql-$version/bin mysql/mysql-$version/
|
||||
done
|
||||
|
||||
# Copy MariaDB client tools (10.6, 12.1) from pre-built assets
|
||||
for version in 10.6 12.1; do
|
||||
mkdir -p mariadb/mariadb-$version
|
||||
cp -r ../../assets/tools/x64/mariadb/mariadb-$version/bin mariadb/mariadb-$version/
|
||||
done
|
||||
|
||||
# Make all binaries executable
|
||||
chmod +x postgresql/*/bin/*
|
||||
chmod +x mysql/*/bin/*
|
||||
chmod +x mariadb/*/bin/*
|
||||
|
||||
echo "Pre-built client tools setup complete"
|
||||
|
||||
- name: Install MongoDB Database Tools
|
||||
run: |
|
||||
cd backend/tools
|
||||
|
||||
# MongoDB Database Tools must be downloaded (not in pre-built assets)
|
||||
# They are backward compatible - single version supports all servers (4.0-8.0)
|
||||
MONGODB_TOOLS_URL="https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-x86_64-100.10.0.deb"
|
||||
|
||||
echo "Downloading MongoDB Database Tools..."
|
||||
wget -q "$MONGODB_TOOLS_URL" -O /tmp/mongodb-database-tools.deb
|
||||
|
||||
echo "Installing MongoDB Database Tools..."
|
||||
dpkg -i /tmp/mongodb-database-tools.deb || apt-get install -f -y --no-install-recommends
|
||||
|
||||
# Create symlinks to tools directory
|
||||
ln -sf /usr/bin/mongodump mongodb/bin/mongodump
|
||||
ln -sf /usr/bin/mongorestore mongodb/bin/mongorestore
|
||||
|
||||
rm -f /tmp/mongodb-database-tools.deb
|
||||
echo "MongoDB Database Tools installed successfully"
|
||||
|
||||
- name: Verify MariaDB client tools exist
|
||||
run: |
|
||||
cd backend/tools
|
||||
@@ -403,7 +484,7 @@ jobs:
|
||||
- name: Run database migrations
|
||||
run: |
|
||||
cd backend
|
||||
go install github.com/pressly/goose/v3/cmd/goose@latest
|
||||
go install github.com/pressly/goose/v3/cmd/goose@v3.24.3
|
||||
goose up
|
||||
|
||||
- name: Run Go tests
|
||||
@@ -415,11 +496,29 @@ jobs:
|
||||
if: always()
|
||||
run: |
|
||||
cd backend
|
||||
# Stop and remove containers (keeping images for next run)
|
||||
docker compose -f docker-compose.yml.example down -v
|
||||
|
||||
# Clean up all data directories created by docker-compose
|
||||
echo "Cleaning up data directories..."
|
||||
rm -rf pgdata || true
|
||||
rm -rf valkey-data || true
|
||||
rm -rf mysqldata || true
|
||||
rm -rf mariadbdata || true
|
||||
rm -rf temp/nas || true
|
||||
rm -rf databasus-data || true
|
||||
|
||||
# Also clean root-level databasus-data if exists
|
||||
cd ..
|
||||
rm -rf databasus-data || true
|
||||
|
||||
echo "Cleanup complete"
|
||||
|
||||
determine-version:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [test-backend, test-frontend]
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: node:20
|
||||
needs: [test-backend, test-frontend, test-agent, e2e-agent]
|
||||
if: ${{ github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, '[skip-release]') }}
|
||||
outputs:
|
||||
should_release: ${{ steps.version_bump.outputs.should_release }}
|
||||
@@ -431,10 +530,9 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Install semver
|
||||
run: npm install -g semver
|
||||
@@ -448,6 +546,7 @@ jobs:
|
||||
|
||||
- name: Analyze commits and determine version bump
|
||||
id: version_bump
|
||||
shell: bash
|
||||
run: |
|
||||
CURRENT_VERSION="${{ steps.current_version.outputs.current_version }}"
|
||||
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||
@@ -467,7 +566,7 @@ jobs:
|
||||
HAS_FIX=false
|
||||
HAS_BREAKING=false
|
||||
|
||||
# Analyze each commit
|
||||
# Analyze each commit - USE PROCESS SUBSTITUTION to avoid subshell variable scope issues
|
||||
while IFS= read -r commit; do
|
||||
if [[ "$commit" =~ ^FEATURE ]]; then
|
||||
HAS_FEATURE=true
|
||||
@@ -485,7 +584,7 @@ jobs:
|
||||
HAS_BREAKING=true
|
||||
echo "Found BREAKING CHANGE: $commit"
|
||||
fi
|
||||
done <<< "$COMMITS"
|
||||
done < <(printf '%s\n' "$COMMITS")
|
||||
|
||||
# Determine version bump
|
||||
if [ "$HAS_BREAKING" = true ]; then
|
||||
@@ -511,10 +610,15 @@ jobs:
|
||||
fi
|
||||
|
||||
build-only:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [test-backend, test-frontend]
|
||||
runs-on: self-hosted
|
||||
needs: [test-backend, test-frontend, test-agent, e2e-agent]
|
||||
if: ${{ github.ref == 'refs/heads/main' && contains(github.event.head_commit.message, '[skip-release]') }}
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -543,12 +647,17 @@ jobs:
|
||||
databasus/databasus:${{ github.sha }}
|
||||
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
needs: [determine-version]
|
||||
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -578,21 +687,33 @@ jobs:
|
||||
databasus/databasus:${{ github.sha }}
|
||||
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: node:20
|
||||
needs: [determine-version, build-and-push]
|
||||
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Generate changelog
|
||||
id: changelog
|
||||
shell: bash
|
||||
run: |
|
||||
NEW_VERSION="${{ needs.determine-version.outputs.new_version }}"
|
||||
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||
@@ -612,6 +733,7 @@ jobs:
|
||||
FIXES=""
|
||||
REFACTORS=""
|
||||
|
||||
# USE PROCESS SUBSTITUTION to avoid subshell variable scope issues
|
||||
while IFS= read -r line; do
|
||||
if [ -n "$line" ]; then
|
||||
COMMIT_MSG=$(echo "$line" | cut -d'|' -f1)
|
||||
@@ -645,7 +767,7 @@ jobs:
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
done <<< "$COMMITS"
|
||||
done < <(printf '%s\n' "$COMMITS")
|
||||
|
||||
# Build changelog sections
|
||||
if [ -n "$FEATURES" ]; then
|
||||
@@ -672,17 +794,6 @@ jobs:
|
||||
echo EOF
|
||||
} >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Update CITATION.cff version
|
||||
run: |
|
||||
VERSION="${{ needs.determine-version.outputs.new_version }}"
|
||||
sed -i "s/^version: .*/version: ${VERSION}/" CITATION.cff
|
||||
sed -i "s/^date-released: .*/date-released: \"$(date +%Y-%m-%d)\"/" CITATION.cff
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
git add CITATION.cff
|
||||
git commit -m "Update CITATION.cff to v${VERSION}" || true
|
||||
git push || true
|
||||
|
||||
- name: Create GitHub Release
|
||||
uses: actions/create-release@v1
|
||||
env:
|
||||
@@ -695,16 +806,33 @@ jobs:
|
||||
prerelease: false
|
||||
|
||||
publish-helm-chart:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: alpine:3.19
|
||||
volumes:
|
||||
- /runner-cache/apk-cache:/etc/apk/cache
|
||||
needs: [determine-version, build-and-push]
|
||||
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apk add --no-cache git bash curl
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4
|
||||
with:
|
||||
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -1,12 +1,16 @@
|
||||
ansible/
|
||||
postgresus_data/
|
||||
postgresus-data/
|
||||
databasus-data/
|
||||
.env
|
||||
pgdata/
|
||||
docker-compose.yml
|
||||
!agent/e2e/docker-compose.yml
|
||||
node_modules/
|
||||
.idea
|
||||
/articles
|
||||
|
||||
.DS_Store
|
||||
/scripts
|
||||
/scripts
|
||||
.vscode/settings.json
|
||||
.claude
|
||||
@@ -6,24 +6,55 @@ repos:
|
||||
hooks:
|
||||
- id: frontend-format
|
||||
name: Frontend Format (Prettier)
|
||||
entry: powershell -Command "cd frontend; npm run format"
|
||||
entry: bash -c "cd frontend && npm run format"
|
||||
language: system
|
||||
files: ^frontend/.*\.(ts|tsx|js|jsx|json|css|md)$
|
||||
pass_filenames: false
|
||||
|
||||
- id: frontend-lint
|
||||
name: Frontend Lint (ESLint)
|
||||
entry: powershell -Command "cd frontend; npm run lint"
|
||||
entry: bash -c "cd frontend && npm run lint"
|
||||
language: system
|
||||
files: ^frontend/.*\.(ts|tsx|js|jsx)$
|
||||
pass_filenames: false
|
||||
|
||||
- id: frontend-build
|
||||
name: Frontend Build
|
||||
entry: bash -c "cd frontend && npm run build"
|
||||
language: system
|
||||
files: ^frontend/.*\.(ts|tsx|js|jsx|json|css)$
|
||||
pass_filenames: false
|
||||
|
||||
# Backend checks
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: backend-format-and-lint
|
||||
name: Backend Format & Lint (golangci-lint)
|
||||
entry: powershell -Command "cd backend; golangci-lint fmt; golangci-lint run"
|
||||
entry: bash -c "cd backend && golangci-lint fmt ./internal/... ./cmd/... && golangci-lint run ./internal/... ./cmd/..."
|
||||
language: system
|
||||
files: ^backend/.*\.go$
|
||||
pass_filenames: false
|
||||
pass_filenames: false
|
||||
|
||||
- id: backend-go-mod-tidy
|
||||
name: Backend Go Mod Tidy
|
||||
entry: bash -c "cd backend && go mod tidy"
|
||||
language: system
|
||||
files: ^backend/.*\.go$
|
||||
pass_filenames: false
|
||||
|
||||
# Agent checks
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: agent-format-and-lint
|
||||
name: Agent Format & Lint (golangci-lint)
|
||||
entry: bash -c "cd agent && golangci-lint fmt ./internal/... ./cmd/... && golangci-lint run ./internal/... ./cmd/..."
|
||||
language: system
|
||||
files: ^agent/.*\.go$
|
||||
pass_filenames: false
|
||||
|
||||
- id: agent-go-mod-tidy
|
||||
name: Agent Go Mod Tidy
|
||||
entry: bash -c "cd agent && go mod tidy"
|
||||
language: system
|
||||
files: ^agent/.*\.go$
|
||||
pass_filenames: false
|
||||
|
||||
@@ -32,5 +32,5 @@ keywords:
|
||||
- mongodb
|
||||
- mariadb
|
||||
license: Apache-2.0
|
||||
version: 2.18.5
|
||||
date-released: "2025-12-30"
|
||||
version: 2.21.0
|
||||
date-released: "2026-01-05"
|
||||
|
||||
163
Dockerfile
163
Dockerfile
@@ -22,7 +22,7 @@ RUN npm run build
|
||||
|
||||
# ========= BUILD BACKEND =========
|
||||
# Backend build stage
|
||||
FROM --platform=$BUILDPLATFORM golang:1.24.4 AS backend-build
|
||||
FROM --platform=$BUILDPLATFORM golang:1.26.1 AS backend-build
|
||||
|
||||
# Make TARGET args available early so tools built here match the final image arch
|
||||
ARG TARGETOS
|
||||
@@ -66,13 +66,52 @@ RUN CGO_ENABLED=0 \
|
||||
go build -o /app/main ./cmd/main.go
|
||||
|
||||
|
||||
# ========= BUILD AGENT =========
|
||||
# Builds the databasus-agent CLI binary for BOTH x86_64 and ARM64.
|
||||
# Both architectures are always built because:
|
||||
# - Databasus server runs on one arch (e.g. amd64)
|
||||
# - The agent runs on remote PostgreSQL servers that may be on a
|
||||
# different arch (e.g. arm64)
|
||||
# - The backend serves the correct binary based on the agent's
|
||||
# ?arch= query parameter
|
||||
#
|
||||
# We cross-compile from the build platform (no QEMU needed) because the
|
||||
# agent is pure Go with zero C dependencies.
|
||||
# CGO_ENABLED=0 produces fully static binaries — no glibc/musl dependency,
|
||||
# so the agent runs on any Linux distro (Alpine, Debian, Ubuntu, RHEL, etc.).
|
||||
# APP_VERSION is baked into the binary via -ldflags so the agent can
|
||||
# compare its version against the server and auto-update when needed.
|
||||
FROM --platform=$BUILDPLATFORM golang:1.26.1 AS agent-build
|
||||
|
||||
ARG APP_VERSION=dev
|
||||
|
||||
WORKDIR /agent
|
||||
|
||||
COPY agent/go.mod ./
|
||||
RUN go mod download
|
||||
|
||||
COPY agent/ ./
|
||||
|
||||
# Build for x86_64 (amd64) — static binary, no glibc dependency
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 \
|
||||
go build -ldflags "-X main.Version=${APP_VERSION}" \
|
||||
-o /agent-binaries/databasus-agent-linux-amd64 ./cmd/main.go
|
||||
|
||||
# Build for ARM64 (arm64) — static binary, no glibc dependency
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=arm64 \
|
||||
go build -ldflags "-X main.Version=${APP_VERSION}" \
|
||||
-o /agent-binaries/databasus-agent-linux-arm64 ./cmd/main.go
|
||||
|
||||
|
||||
# ========= RUNTIME =========
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
# Add version metadata to runtime image
|
||||
ARG APP_VERSION=dev
|
||||
ARG TARGETARCH
|
||||
LABEL org.opencontainers.image.version=$APP_VERSION
|
||||
ENV APP_VERSION=$APP_VERSION
|
||||
ENV CONTAINER_ARCH=$TARGETARCH
|
||||
|
||||
# Set production mode for Docker containers
|
||||
ENV ENV_MODE=production
|
||||
@@ -123,6 +162,15 @@ RUN wget -qO- https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add -
|
||||
apt-get install -y --no-install-recommends postgresql-17 && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Valkey server from debian repository
|
||||
# Valkey is only accessible internally (localhost) - not exposed outside container
|
||||
RUN wget -O /usr/share/keyrings/greensec.github.io-valkey-debian.key https://greensec.github.io/valkey-debian/public.key && \
|
||||
echo "deb [signed-by=/usr/share/keyrings/greensec.github.io-valkey-debian.key] https://greensec.github.io/valkey-debian/repo $(lsb_release -cs) main" \
|
||||
> /etc/apt/sources.list.d/valkey-debian.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends valkey && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# ========= Install rclone =========
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends rclone && \
|
||||
@@ -172,19 +220,23 @@ RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
|
||||
# ========= Install MongoDB Database Tools =========
|
||||
# Note: MongoDB Database Tools are backward compatible - single version supports all server versions (4.0-8.0)
|
||||
# Use dpkg with apt-get -f install to handle dependencies
|
||||
# Note: For ARM64, we use Ubuntu 22.04 package as MongoDB doesn't provide Debian 12 ARM64 packages
|
||||
RUN apt-get update && \
|
||||
if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
wget -q https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-x86_64-100.10.0.deb -O /tmp/mongodb-database-tools.deb; \
|
||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||
wget -q https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-aarch64-100.10.0.deb -O /tmp/mongodb-database-tools.deb; \
|
||||
wget -q https://fastdl.mongodb.org/tools/db/mongodb-database-tools-ubuntu2204-arm64-100.10.0.deb -O /tmp/mongodb-database-tools.deb; \
|
||||
fi && \
|
||||
dpkg -i /tmp/mongodb-database-tools.deb || true && \
|
||||
apt-get install -f -y --no-install-recommends && \
|
||||
rm /tmp/mongodb-database-tools.deb && \
|
||||
dpkg -i /tmp/mongodb-database-tools.deb || apt-get install -f -y --no-install-recommends && \
|
||||
rm -f /tmp/mongodb-database-tools.deb && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
ln -sf /usr/bin/mongodump /usr/local/mongodb-database-tools/bin/mongodump && \
|
||||
ln -sf /usr/bin/mongorestore /usr/local/mongodb-database-tools/bin/mongorestore
|
||||
mkdir -p /usr/local/mongodb-database-tools/bin && \
|
||||
if [ -f /usr/bin/mongodump ]; then \
|
||||
ln -sf /usr/bin/mongodump /usr/local/mongodb-database-tools/bin/mongodump; \
|
||||
fi && \
|
||||
if [ -f /usr/bin/mongorestore ]; then \
|
||||
ln -sf /usr/bin/mongorestore /usr/local/mongodb-database-tools/bin/mongorestore; \
|
||||
fi
|
||||
|
||||
# Create postgres user and set up directories
|
||||
RUN useradd -m -s /bin/bash postgres || true && \
|
||||
@@ -205,6 +257,10 @@ COPY backend/migrations ./migrations
|
||||
# Copy UI files
|
||||
COPY --from=backend-build /app/ui/build ./ui/build
|
||||
|
||||
# Copy agent binaries (both architectures) — served by the backend
|
||||
# at GET /api/v1/system/agent?arch=amd64|arm64
|
||||
COPY --from=agent-build /agent-binaries ./agent-binaries
|
||||
|
||||
# Copy .env file (with fallback to .env.production.example)
|
||||
COPY backend/.env* /app/
|
||||
RUN if [ ! -f /app/.env ]; then \
|
||||
@@ -238,10 +294,69 @@ fi
|
||||
# PostgreSQL 17 binary paths
|
||||
PG_BIN="/usr/lib/postgresql/17/bin"
|
||||
|
||||
# Generate runtime configuration for frontend
|
||||
echo "Generating runtime configuration..."
|
||||
|
||||
# Detect if email is configured (both SMTP_HOST and DATABASUS_URL must be set)
|
||||
if [ -n "\${SMTP_HOST:-}" ] && [ -n "\${DATABASUS_URL:-}" ]; then
|
||||
IS_EMAIL_CONFIGURED="true"
|
||||
else
|
||||
IS_EMAIL_CONFIGURED="false"
|
||||
fi
|
||||
|
||||
cat > /app/ui/build/runtime-config.js <<JSEOF
|
||||
// Runtime configuration injected at container startup
|
||||
// This file is generated dynamically and should not be edited manually
|
||||
window.__RUNTIME_CONFIG__ = {
|
||||
IS_CLOUD: '\${IS_CLOUD:-false}',
|
||||
GITHUB_CLIENT_ID: '\${GITHUB_CLIENT_ID:-}',
|
||||
GOOGLE_CLIENT_ID: '\${GOOGLE_CLIENT_ID:-}',
|
||||
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED',
|
||||
CLOUDFLARE_TURNSTILE_SITE_KEY: '\${CLOUDFLARE_TURNSTILE_SITE_KEY:-}',
|
||||
CONTAINER_ARCH: '\${CONTAINER_ARCH:-unknown}'
|
||||
};
|
||||
JSEOF
|
||||
|
||||
# Inject analytics script if provided (only if not already injected)
|
||||
if [ -n "\${ANALYTICS_SCRIPT:-}" ]; then
|
||||
if ! grep -q "rybbit.databasus.com" /app/ui/build/index.html 2>/dev/null; then
|
||||
echo "Injecting analytics script..."
|
||||
sed -i "s#</head># \${ANALYTICS_SCRIPT}\\
|
||||
</head>#" /app/ui/build/index.html
|
||||
fi
|
||||
fi
|
||||
|
||||
# Ensure proper ownership of data directory
|
||||
echo "Setting up data directory permissions..."
|
||||
mkdir -p /databasus-data/pgdata
|
||||
mkdir -p /databasus-data/temp
|
||||
mkdir -p /databasus-data/backups
|
||||
chown -R postgres:postgres /databasus-data
|
||||
chmod 700 /databasus-data/temp
|
||||
|
||||
# ========= Start Valkey (internal cache) =========
|
||||
echo "Configuring Valkey cache..."
|
||||
cat > /tmp/valkey.conf << 'VALKEY_CONFIG'
|
||||
port 6379
|
||||
bind 127.0.0.1
|
||||
protected-mode yes
|
||||
save ""
|
||||
maxmemory 256mb
|
||||
maxmemory-policy allkeys-lru
|
||||
VALKEY_CONFIG
|
||||
|
||||
echo "Starting Valkey..."
|
||||
valkey-server /tmp/valkey.conf &
|
||||
VALKEY_PID=\$!
|
||||
|
||||
echo "Waiting for Valkey to be ready..."
|
||||
for i in {1..30}; do
|
||||
if valkey-cli ping >/dev/null 2>&1; then
|
||||
echo "Valkey is ready!"
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
# Initialize PostgreSQL if not already initialized
|
||||
if [ ! -s "/databasus-data/pgdata/PG_VERSION" ]; then
|
||||
@@ -323,6 +438,8 @@ fi
|
||||
# Create database and set password for postgres user
|
||||
echo "Setting up database and user..."
|
||||
gosu postgres \$PG_BIN/psql -p 5437 -h localhost -d postgres << 'SQL'
|
||||
|
||||
# We use stub password, because internal DB is not exposed outside container
|
||||
ALTER USER postgres WITH PASSWORD 'Q1234567';
|
||||
SELECT 'CREATE DATABASE databasus OWNER postgres'
|
||||
WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = 'databasus')
|
||||
@@ -332,9 +449,37 @@ SQL
|
||||
|
||||
# Start the main application
|
||||
echo "Starting Databasus application..."
|
||||
|
||||
# Check and warn about external database/Valkey usage
|
||||
if [ -n "\${DANGEROUS_EXTERNAL_DATABASE_DSN:-}" ]; then
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "WARNING: Using external database"
|
||||
echo "=========================================="
|
||||
echo "DANGEROUS_EXTERNAL_DATABASE_DSN is set."
|
||||
echo "Application will connect to external PostgreSQL instead of internal instance."
|
||||
echo "Internal PostgreSQL is still running in the background."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
fi
|
||||
|
||||
if [ -n "\${DANGEROUS_VALKEY_HOST:-}" ]; then
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "WARNING: Using external Valkey"
|
||||
echo "=========================================="
|
||||
echo "DANGEROUS_VALKEY_HOST is set."
|
||||
echo "Application will connect to external Valkey instead of internal instance."
|
||||
echo "Internal Valkey is still running in the background."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
fi
|
||||
|
||||
exec ./main
|
||||
EOF
|
||||
|
||||
LABEL org.opencontainers.image.source="https://github.com/databasus/databasus"
|
||||
|
||||
RUN chmod +x /app/start.sh
|
||||
|
||||
EXPOSE 4005
|
||||
@@ -343,4 +488,4 @@ EXPOSE 4005
|
||||
VOLUME ["/databasus-data"]
|
||||
|
||||
ENTRYPOINT ["/app/start.sh"]
|
||||
CMD []
|
||||
CMD []
|
||||
|
||||
69
README.md
69
README.md
@@ -2,7 +2,7 @@
|
||||
<img src="assets/logo.svg" alt="Databasus Logo" width="250"/>
|
||||
|
||||
<h3>Backup tool for PostgreSQL, MySQL and MongoDB</h3>
|
||||
<p>Databasus is a free, open source and self-hosted tool to backup databases. Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.). Previously known as Postgresus (see migration guide).</p>
|
||||
<p>Databasus is a free, open source and self-hosted tool to backup databases (with focus on PostgreSQL). Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.)</p>
|
||||
|
||||
<!-- Badges -->
|
||||
[](https://www.postgresql.org/)
|
||||
@@ -11,7 +11,7 @@
|
||||
[](https://www.mongodb.com/)
|
||||
<br />
|
||||
[](LICENSE)
|
||||
[](https://hub.docker.com/r/rostislavdugin/postgresus)
|
||||
[](https://hub.docker.com/r/databasus/databasus)
|
||||
[](https://github.com/databasus/databasus)
|
||||
[](https://github.com/databasus/databasus)
|
||||
[](https://github.com/databasus/databasus)
|
||||
@@ -31,8 +31,6 @@
|
||||
<img src="assets/dashboard-dark.svg" alt="Databasus Dark Dashboard" width="800" style="margin-bottom: 10px;"/>
|
||||
|
||||
<img src="assets/dashboard.svg" alt="Databasus Dashboard" width="800"/>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
@@ -43,7 +41,7 @@
|
||||
|
||||
- **PostgreSQL**: 12, 13, 14, 15, 16, 17 and 18
|
||||
- **MySQL**: 5.7, 8 and 9
|
||||
- **MariaDB**: 10 and 11
|
||||
- **MariaDB**: 10, 11 and 12
|
||||
- **MongoDB**: 4, 5, 6, 7 and 8
|
||||
|
||||
### 🔄 **Scheduled backups**
|
||||
@@ -52,6 +50,13 @@
|
||||
- **Precise timing**: run backups at specific times (e.g., 4 AM during low traffic)
|
||||
- **Smart compression**: 4-8x space savings with balanced compression (~20% overhead)
|
||||
|
||||
### 🗑️ **Retention policies**
|
||||
|
||||
- **Time period**: Keep backups for a fixed duration (e.g., 7 days, 3 months, 1 year)
|
||||
- **Count**: Keep a fixed number of the most recent backups (e.g., last 30)
|
||||
- **GFS (Grandfather-Father-Son)**: Layered retention — keep hourly, daily, weekly, monthly and yearly backups independently for fine-grained long-term history (enterprises requirement)
|
||||
- **Size limits**: Set per-backup and total storage size caps to control storage usage
|
||||
|
||||
### 🗄️ **Multiple storage destinations** <a href="https://databasus.com/storages">(view supported)</a>
|
||||
|
||||
- **Local storage**: Keep backups on your VPS/server
|
||||
@@ -71,6 +76,8 @@
|
||||
- **Encryption for secrets**: Any sensitive data is encrypted and never exposed, even in logs or error messages
|
||||
- **Read-only user**: Databasus uses a read-only user by default for backups and never stores anything that can modify your data
|
||||
|
||||
It is also important for Databasus that you are able to decrypt and restore backups from storages (local, S3, etc.) without Databasus itself. To do so, read our guide on [how to recover directly from storage](https://databasus.com/how-to-recover-without-databasus). We avoid "vendor lock-in" even to open source tool!
|
||||
|
||||
### 👥 **Suitable for teams** <a href="https://databasus.com/access-management">(docs)</a>
|
||||
|
||||
- **Workspaces**: Group databases, notifiers and storages for different projects or teams
|
||||
@@ -220,8 +227,9 @@ For more options (NodePort, TLS, HTTPRoute for Gateway API), see the [Helm chart
|
||||
3. **Configure schedule**: Choose from hourly, daily, weekly, monthly or cron intervals
|
||||
4. **Set database connection**: Enter your database credentials and connection details
|
||||
5. **Choose storage**: Select where to store your backups (local, S3, Google Drive, etc.)
|
||||
6. **Add notifications** (optional): Configure email, Telegram, Slack, or webhook notifications
|
||||
7. **Save and start**: Databasus will validate settings and begin the backup schedule
|
||||
6. **Configure retention policy**: Choose time period, count or GFS to control how long backups are kept
|
||||
7. **Add notifications** (optional): Configure email, Telegram, Slack, or webhook notifications
|
||||
8. **Save and start**: Databasus will validate settings and begin the backup schedule
|
||||
|
||||
### 🔑 Resetting password <a href="https://databasus.com/password">(docs)</a>
|
||||
|
||||
@@ -233,66 +241,37 @@ docker exec -it databasus ./main --new-password="YourNewSecurePassword123" --ema
|
||||
|
||||
Replace `admin` with the actual email address of the user whose password you want to reset.
|
||||
|
||||
### 💾 Backuping Databasus itself
|
||||
|
||||
After installation, it is also recommended to <a href="https://databasus.com/faq/#backup-databasus">backup your Databasus itself</a> or, at least, to copy secret key used for encryption (30 seconds is needed). So you are able to restore from your encrypted backups if you lose access to the server with Databasus or it is corrupted.
|
||||
|
||||
---
|
||||
|
||||
## 📝 License
|
||||
|
||||
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details
|
||||
|
||||
---
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
Contributions are welcome! Read the <a href="https://databasus.com/contribute">contributing guide</a> for more details, priorities and rules. If you want to contribute but don't know where to start, message me on Telegram [@rostislav_dugin](https://t.me/rostislav_dugin)
|
||||
|
||||
Also you can join our large community of developers, DBAs and DevOps engineers on Telegram [@databasus_community](https://t.me/databasus_community).
|
||||
|
||||
--
|
||||
|
||||
## 📖 Migration guide
|
||||
|
||||
Databasus is the new name for Postgresus. You can stay with latest version of Postgresus if you wish. If you want to migrate - follow installation steps for Databasus itself.
|
||||
|
||||
Just renaming an image is not enough as Postgresus and Databasus use different data folders and internal database naming.
|
||||
|
||||
You can put a new Databasus image with updated volume near the old Postgresus and run it (stop Postgresus before):
|
||||
|
||||
```
|
||||
services:
|
||||
databasus:
|
||||
container_name: databasus
|
||||
image: databasus/databasus:latest
|
||||
ports:
|
||||
- "4005:4005"
|
||||
volumes:
|
||||
- ./databasus-data:/databasus-data
|
||||
restart: unless-stopped
|
||||
```
|
||||
|
||||
Then manually move databases from Postgresus to Databasus.
|
||||
|
||||
### Why was Postgresus renamed to Databasus?
|
||||
|
||||
Databasus has been developed since 2023. It was internal tool to backup production and home projects databases. In start of 2025 it was released as open source project on GitHub. By the end of 2025 it became popular and the time for renaming has come in December 2025.
|
||||
|
||||
It was an important step for the project to grow. Actually, there are a couple of reasons:
|
||||
|
||||
1. Postgresus is no longer a little tool that just adds UI for pg_dump for little projects. It became a tool both for individual users, DevOps, DBAs, teams, companies and even large enterprises. Tens of thousands of users use Postgresus every day. Postgresus grew into a reliable backup management tool. Initial positioning is no longer suitable: the project is not just a UI wrapper, it's a solid backup management system now (despite it's still easy to use).
|
||||
|
||||
2. New databases are supported: although the primary focus is PostgreSQL (with 100% support in the most efficient way) and always will be, Databasus added support for MySQL, MariaDB and MongoDB. Later more databases will be supported.
|
||||
|
||||
3. Trademark issue: "postgres" is a trademark of PostgreSQL Inc. and cannot be used in the project name. So for safety and legal reasons, we had to rename the project.
|
||||
|
||||
## AI disclaimer
|
||||
|
||||
There have been questions about AI usage in project development in issues and discussions. As the project focuses on security, reliability and production usage, it's important to explain how AI is used in the development process.
|
||||
|
||||
First of all, we are proud to say that Databasus has been accepted into both [Claude for Open Source](https://claude.com/contact-sales/claude-for-oss) by Anthropic and [Codex for Open Source](https://developers.openai.com/codex/community/codex-for-oss/) by OpenAI in March 2026. For us it is one more signal that the project was recognized as important open-source software and was as critical infrastructure worth supporting independently by two of the world's leading AI companies. Read more at [databasus.com/faq](https://databasus.com/faq#oss-programs).
|
||||
|
||||
Despite of this, we have the following rules how AI is used in the development process:
|
||||
|
||||
AI is used as a helper for:
|
||||
|
||||
- verification of code quality and searching for vulnerabilities
|
||||
- cleaning up and improving documentation, comments and code
|
||||
- assistance during development
|
||||
- double-checking PRs and commits after human review
|
||||
- additional security analysis of PRs via Codex Security
|
||||
|
||||
AI is not used for:
|
||||
|
||||
|
||||
1
agent/.env.example
Normal file
1
agent/.env.example
Normal file
@@ -0,0 +1 @@
|
||||
ENV_MODE=development
|
||||
24
agent/.gitignore
vendored
Normal file
24
agent/.gitignore
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
main
|
||||
.env
|
||||
docker-compose.yml
|
||||
!e2e/docker-compose.yml
|
||||
pgdata
|
||||
pgdata_test/
|
||||
mysqldata/
|
||||
mariadbdata/
|
||||
main.exe
|
||||
swagger/
|
||||
swagger/*
|
||||
swagger/docs.go
|
||||
swagger/swagger.json
|
||||
swagger/swagger.yaml
|
||||
postgresus-backend.exe
|
||||
databasus-backend.exe
|
||||
ui/build/*
|
||||
pgdata-for-restore/
|
||||
temp/
|
||||
cmd.exe
|
||||
temp/
|
||||
valkey-data/
|
||||
victoria-logs-data/
|
||||
databasus.json
|
||||
41
agent/.golangci.yml
Normal file
41
agent/.golangci.yml
Normal file
@@ -0,0 +1,41 @@
|
||||
version: "2"
|
||||
|
||||
run:
|
||||
timeout: 5m
|
||||
tests: false
|
||||
concurrency: 4
|
||||
|
||||
linters:
|
||||
default: standard
|
||||
enable:
|
||||
- funcorder
|
||||
- bodyclose
|
||||
- errorlint
|
||||
- gocritic
|
||||
- unconvert
|
||||
- misspell
|
||||
- errname
|
||||
- noctx
|
||||
- modernize
|
||||
|
||||
settings:
|
||||
errcheck:
|
||||
check-type-assertions: true
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
- gofumpt
|
||||
- golines
|
||||
- gci
|
||||
|
||||
settings:
|
||||
golines:
|
||||
max-len: 120
|
||||
gofumpt:
|
||||
module-path: databasus-agent
|
||||
extra-rules: true
|
||||
gci:
|
||||
sections:
|
||||
- standard
|
||||
- default
|
||||
- localmodule
|
||||
26
agent/Makefile
Normal file
26
agent/Makefile
Normal file
@@ -0,0 +1,26 @@
|
||||
.PHONY: run build test lint e2e e2e-clean
|
||||
|
||||
# Usage: make run ARGS="start --pg-host localhost"
|
||||
run:
|
||||
go run cmd/main.go $(ARGS)
|
||||
|
||||
build:
|
||||
CGO_ENABLED=0 go build -ldflags "-X main.Version=$(VERSION)" -o databasus-agent ./cmd/main.go
|
||||
|
||||
test:
|
||||
go test -count=1 -failfast ./internal/...
|
||||
|
||||
lint:
|
||||
golangci-lint fmt ./cmd/... ./internal/... ./e2e/... && golangci-lint run ./cmd/... ./internal/... ./e2e/...
|
||||
|
||||
e2e:
|
||||
cd e2e && docker compose build
|
||||
cd e2e && docker compose run --rm e2e-agent-builder
|
||||
cd e2e && docker compose up -d e2e-postgres e2e-mock-server
|
||||
cd e2e && docker compose run --rm e2e-agent-runner
|
||||
cd e2e && docker compose run --rm e2e-agent-docker
|
||||
cd e2e && docker compose down -v
|
||||
|
||||
e2e-clean:
|
||||
cd e2e && docker compose down -v --rmi local
|
||||
rm -rf e2e/artifacts
|
||||
170
agent/cmd/main.go
Normal file
170
agent/cmd/main.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
"databasus-agent/internal/features/start"
|
||||
"databasus-agent/internal/features/upgrade"
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
var Version = "dev"
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 2 {
|
||||
printUsage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
switch os.Args[1] {
|
||||
case "start":
|
||||
runStart(os.Args[2:])
|
||||
case "stop":
|
||||
runStop()
|
||||
case "status":
|
||||
runStatus()
|
||||
case "restore":
|
||||
runRestore(os.Args[2:])
|
||||
case "version":
|
||||
fmt.Println(Version)
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "unknown command: %s\n", os.Args[1])
|
||||
printUsage()
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runStart(args []string) {
|
||||
fs := flag.NewFlagSet("start", flag.ExitOnError)
|
||||
|
||||
isDebug := fs.Bool("debug", false, "Enable debug logging")
|
||||
isSkipUpdate := fs.Bool("skip-update", false, "Skip auto-update check")
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.LoadFromJSONAndArgs(fs, args)
|
||||
|
||||
if err := cfg.SaveToJSON(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to save config: %v\n", err)
|
||||
}
|
||||
|
||||
logger.Init(*isDebug)
|
||||
log := logger.GetLogger()
|
||||
|
||||
isDev := checkIsDevelopment()
|
||||
runUpdateCheck(cfg.DatabasusHost, *isSkipUpdate, isDev, log)
|
||||
|
||||
if err := start.Run(cfg, log); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runStop() {
|
||||
logger.Init(false)
|
||||
logger.GetLogger().Info("stop: stub — not yet implemented")
|
||||
}
|
||||
|
||||
func runStatus() {
|
||||
logger.Init(false)
|
||||
logger.GetLogger().Info("status: stub — not yet implemented")
|
||||
}
|
||||
|
||||
func runRestore(args []string) {
|
||||
fs := flag.NewFlagSet("restore", flag.ExitOnError)
|
||||
|
||||
targetDir := fs.String("target-dir", "", "Target pgdata directory")
|
||||
backupID := fs.String("backup-id", "", "Full backup UUID (optional)")
|
||||
targetTime := fs.String("target-time", "", "PITR target time in RFC3339 (optional)")
|
||||
isYes := fs.Bool("yes", false, "Skip confirmation prompt")
|
||||
isDebug := fs.Bool("debug", false, "Enable debug logging")
|
||||
isSkipUpdate := fs.Bool("skip-update", false, "Skip auto-update check")
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.LoadFromJSONAndArgs(fs, args)
|
||||
|
||||
if err := cfg.SaveToJSON(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to save config: %v\n", err)
|
||||
}
|
||||
|
||||
logger.Init(*isDebug)
|
||||
log := logger.GetLogger()
|
||||
|
||||
isDev := checkIsDevelopment()
|
||||
runUpdateCheck(cfg.DatabasusHost, *isSkipUpdate, isDev, log)
|
||||
|
||||
log.Info("restore: stub — not yet implemented",
|
||||
"targetDir", *targetDir,
|
||||
"backupId", *backupID,
|
||||
"targetTime", *targetTime,
|
||||
"yes", *isYes,
|
||||
)
|
||||
}
|
||||
|
||||
func printUsage() {
|
||||
fmt.Fprintln(os.Stderr, "Usage: databasus-agent <command> [flags]")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "Commands:")
|
||||
fmt.Fprintln(os.Stderr, " start Start the agent (WAL archiving + basebackups)")
|
||||
fmt.Fprintln(os.Stderr, " stop Stop a running agent")
|
||||
fmt.Fprintln(os.Stderr, " status Show agent status")
|
||||
fmt.Fprintln(os.Stderr, " restore Restore a database from backup")
|
||||
fmt.Fprintln(os.Stderr, " version Print agent version")
|
||||
}
|
||||
|
||||
func runUpdateCheck(host string, isSkipUpdate, isDev bool, log *slog.Logger) {
|
||||
if isSkipUpdate {
|
||||
return
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if err := upgrade.CheckAndUpdate(host, Version, isDev, log); err != nil {
|
||||
log.Error("Auto-update failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func checkIsDevelopment() bool {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for range 3 {
|
||||
if data, err := os.ReadFile(filepath.Join(dir, ".env")); err == nil {
|
||||
return parseEnvMode(data)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
dir = filepath.Dir(dir)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func parseEnvMode(data []byte) bool {
|
||||
for line := range strings.SplitSeq(string(data), "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) == 2 && strings.TrimSpace(parts[0]) == "ENV_MODE" {
|
||||
return strings.TrimSpace(parts[1]) == "development"
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
1
agent/e2e/.gitignore
vendored
Normal file
1
agent/e2e/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
artifacts/
|
||||
13
agent/e2e/Dockerfile.agent-builder
Normal file
13
agent/e2e/Dockerfile.agent-builder
Normal file
@@ -0,0 +1,13 @@
|
||||
# Builds agent binaries with different versions so
|
||||
# we can test upgrade behavior (v1 -> v2)
|
||||
FROM golang:1.26.1-alpine AS build
|
||||
WORKDIR /src
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
RUN CGO_ENABLED=0 go build -ldflags "-X main.Version=v1.0.0" -o /out/agent-v1 ./cmd/main.go
|
||||
RUN CGO_ENABLED=0 go build -ldflags "-X main.Version=v2.0.0" -o /out/agent-v2 ./cmd/main.go
|
||||
|
||||
FROM alpine:3.21
|
||||
COPY --from=build /out/ /out/
|
||||
CMD ["cp", "-v", "/out/agent-v1", "/out/agent-v2", "/artifacts/"]
|
||||
8
agent/e2e/Dockerfile.agent-docker
Normal file
8
agent/e2e/Dockerfile.agent-docker
Normal file
@@ -0,0 +1,8 @@
|
||||
# Runs pg_basebackup-via-docker-exec test (test 5) which tests
|
||||
# that the agent can connect to Postgres inside Docker container
|
||||
FROM docker:27-cli
|
||||
|
||||
RUN apk add --no-cache bash curl
|
||||
|
||||
WORKDIR /tmp
|
||||
ENTRYPOINT []
|
||||
14
agent/e2e/Dockerfile.agent-runner
Normal file
14
agent/e2e/Dockerfile.agent-runner
Normal file
@@ -0,0 +1,14 @@
|
||||
# Runs upgrade and host-mode pg_basebackup tests (tests 1-4). Needs
|
||||
# Postgres client tools to be installed inside the system
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ca-certificates curl gnupg2 postgresql-common && \
|
||||
/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
postgresql-client-17 && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /tmp
|
||||
ENTRYPOINT []
|
||||
10
agent/e2e/Dockerfile.mock-server
Normal file
10
agent/e2e/Dockerfile.mock-server
Normal file
@@ -0,0 +1,10 @@
|
||||
# Mock databasus API server for version checks and binary downloads. Just
|
||||
# serves static responses and files from the `artifacts` directory.
|
||||
FROM golang:1.26.1-alpine AS build
|
||||
WORKDIR /app
|
||||
COPY mock-server/main.go .
|
||||
RUN CGO_ENABLED=0 go build -o mock-server main.go
|
||||
|
||||
FROM alpine:3.21
|
||||
COPY --from=build /app/mock-server /usr/local/bin/mock-server
|
||||
ENTRYPOINT ["mock-server"]
|
||||
64
agent/e2e/docker-compose.yml
Normal file
64
agent/e2e/docker-compose.yml
Normal file
@@ -0,0 +1,64 @@
|
||||
services:
|
||||
e2e-agent-builder:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: e2e/Dockerfile.agent-builder
|
||||
volumes:
|
||||
- ./artifacts:/artifacts
|
||||
container_name: e2e-agent-builder
|
||||
|
||||
e2e-postgres:
|
||||
image: postgres:17
|
||||
environment:
|
||||
POSTGRES_DB: testdb
|
||||
POSTGRES_USER: testuser
|
||||
POSTGRES_PASSWORD: testpassword
|
||||
container_name: e2e-agent-postgres
|
||||
command: postgres -c wal_level=replica -c max_wal_senders=3
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U testuser -d testdb"]
|
||||
interval: 2s
|
||||
timeout: 5s
|
||||
retries: 30
|
||||
|
||||
e2e-mock-server:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.mock-server
|
||||
volumes:
|
||||
- ./artifacts:/artifacts:ro
|
||||
container_name: e2e-mock-server
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "-q", "--spider", "http://localhost:4050/health"]
|
||||
interval: 2s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
e2e-agent-runner:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.agent-runner
|
||||
volumes:
|
||||
- ./artifacts:/opt/agent/artifacts:ro
|
||||
- ./scripts:/opt/agent/scripts:ro
|
||||
depends_on:
|
||||
e2e-postgres:
|
||||
condition: service_healthy
|
||||
e2e-mock-server:
|
||||
condition: service_healthy
|
||||
container_name: e2e-agent-runner
|
||||
command: ["bash", "/opt/agent/scripts/run-all.sh", "host"]
|
||||
|
||||
e2e-agent-docker:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.agent-docker
|
||||
volumes:
|
||||
- ./artifacts:/opt/agent/artifacts:ro
|
||||
- ./scripts:/opt/agent/scripts:ro
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
depends_on:
|
||||
e2e-postgres:
|
||||
condition: service_healthy
|
||||
container_name: e2e-agent-docker
|
||||
command: ["bash", "/opt/agent/scripts/run-all.sh", "docker"]
|
||||
84
agent/e2e/mock-server/main.go
Normal file
84
agent/e2e/mock-server/main.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type server struct {
|
||||
mu sync.RWMutex
|
||||
version string
|
||||
binaryPath string
|
||||
}
|
||||
|
||||
func main() {
|
||||
version := "v2.0.0"
|
||||
binaryPath := "/artifacts/agent-v2"
|
||||
port := "4050"
|
||||
|
||||
s := &server{version: version, binaryPath: binaryPath}
|
||||
|
||||
http.HandleFunc("/api/v1/system/version", s.handleVersion)
|
||||
http.HandleFunc("/api/v1/system/agent", s.handleAgentDownload)
|
||||
http.HandleFunc("/mock/set-version", s.handleSetVersion)
|
||||
http.HandleFunc("/health", s.handleHealth)
|
||||
|
||||
addr := ":" + port
|
||||
log.Printf("Mock server starting on %s (version=%s, binary=%s)", addr, version, binaryPath)
|
||||
|
||||
if err := http.ListenAndServe(addr, nil); err != nil {
|
||||
log.Fatalf("Server failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) handleVersion(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.RLock()
|
||||
v := s.version
|
||||
s.mu.RUnlock()
|
||||
|
||||
log.Printf("GET /api/v1/system/version -> %s", v)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"version": v})
|
||||
}
|
||||
|
||||
func (s *server) handleAgentDownload(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.RLock()
|
||||
path := s.binaryPath
|
||||
s.mu.RUnlock()
|
||||
|
||||
log.Printf("GET /api/v1/system/agent (arch=%s) -> serving %s", r.URL.Query().Get("arch"), path)
|
||||
|
||||
http.ServeFile(w, r, path)
|
||||
}
|
||||
|
||||
func (s *server) handleSetVersion(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.version = body.Version
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Printf("POST /mock/set-version -> %s", body.Version)
|
||||
|
||||
_, _ = fmt.Fprintf(w, "version set to %s", body.Version)
|
||||
}
|
||||
|
||||
func (s *server) handleHealth(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}
|
||||
48
agent/e2e/scripts/run-all.sh
Normal file
48
agent/e2e/scripts/run-all.sh
Normal file
@@ -0,0 +1,48 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
MODE="${1:-host}"
|
||||
SCRIPT_DIR="$(dirname "$0")"
|
||||
PASSED=0
|
||||
FAILED=0
|
||||
|
||||
run_test() {
|
||||
local name="$1"
|
||||
local script="$2"
|
||||
|
||||
echo ""
|
||||
echo "========================================"
|
||||
echo " $name"
|
||||
echo "========================================"
|
||||
|
||||
if bash "$script"; then
|
||||
echo " PASSED: $name"
|
||||
PASSED=$((PASSED + 1))
|
||||
else
|
||||
echo " FAILED: $name"
|
||||
FAILED=$((FAILED + 1))
|
||||
fi
|
||||
}
|
||||
|
||||
if [ "$MODE" = "host" ]; then
|
||||
run_test "Test 1: Upgrade success (v1 -> v2)" "$SCRIPT_DIR/test-upgrade-success.sh"
|
||||
run_test "Test 2: Upgrade skip (version matches)" "$SCRIPT_DIR/test-upgrade-skip.sh"
|
||||
run_test "Test 3: pg_basebackup in PATH" "$SCRIPT_DIR/test-pg-host-path.sh"
|
||||
run_test "Test 4: pg_basebackup via bindir" "$SCRIPT_DIR/test-pg-host-bindir.sh"
|
||||
|
||||
elif [ "$MODE" = "docker" ]; then
|
||||
run_test "Test 5: pg_basebackup via docker exec" "$SCRIPT_DIR/test-pg-docker-exec.sh"
|
||||
|
||||
else
|
||||
echo "Unknown mode: $MODE (expected 'host' or 'docker')"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "========================================"
|
||||
echo " Results: $PASSED passed, $FAILED failed"
|
||||
echo "========================================"
|
||||
|
||||
if [ "$FAILED" -gt 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
51
agent/e2e/scripts/test-pg-docker-exec.sh
Normal file
51
agent/e2e/scripts/test-pg-docker-exec.sh
Normal file
@@ -0,0 +1,51 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
PG_CONTAINER="e2e-agent-postgres"
|
||||
|
||||
# Copy agent binary
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Verify docker CLI works and PG container is accessible
|
||||
if ! docker exec "$PG_CONTAINER" pg_basebackup --version > /dev/null 2>&1; then
|
||||
echo "FAIL: Cannot reach pg_basebackup inside container $PG_CONTAINER (test setup issue)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run start with --skip-update and pg-type=docker
|
||||
echo "Running agent start (pg_basebackup via docker exec)..."
|
||||
OUTPUT=$("$AGENT" start \
|
||||
--skip-update \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--wal-dir /tmp/wal \
|
||||
--pg-type docker \
|
||||
--pg-docker-container-name "$PG_CONTAINER" 2>&1)
|
||||
|
||||
EXIT_CODE=$?
|
||||
echo "$OUTPUT"
|
||||
|
||||
if [ "$EXIT_CODE" -ne 0 ]; then
|
||||
echo "FAIL: Agent exited with code $EXIT_CODE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "pg_basebackup verified (docker)"; then
|
||||
echo "FAIL: Expected output to contain 'pg_basebackup verified (docker)'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "PostgreSQL connection verified"; then
|
||||
echo "FAIL: Expected output to contain 'PostgreSQL connection verified'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "pg_basebackup found via docker exec and DB connection verified"
|
||||
57
agent/e2e/scripts/test-pg-host-bindir.sh
Normal file
57
agent/e2e/scripts/test-pg-host-bindir.sh
Normal file
@@ -0,0 +1,57 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
CUSTOM_BIN_DIR="/opt/pg/bin"
|
||||
|
||||
# Copy agent binary
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Move pg_basebackup out of PATH into custom directory
|
||||
mkdir -p "$CUSTOM_BIN_DIR"
|
||||
cp "$(which pg_basebackup)" "$CUSTOM_BIN_DIR/pg_basebackup"
|
||||
|
||||
# Hide the system one by prepending an empty dir to PATH
|
||||
export PATH="/opt/empty-path:$PATH"
|
||||
mkdir -p /opt/empty-path
|
||||
|
||||
# Verify pg_basebackup is NOT directly callable from default location
|
||||
# (we copied it, but the original is still there in debian — so we test
|
||||
# that the agent uses the custom dir, not PATH, by checking the output)
|
||||
|
||||
# Run start with --skip-update and custom bin dir
|
||||
echo "Running agent start (pg_basebackup via --pg-host-bin-dir)..."
|
||||
OUTPUT=$("$AGENT" start \
|
||||
--skip-update \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--wal-dir /tmp/wal \
|
||||
--pg-type host \
|
||||
--pg-host-bin-dir "$CUSTOM_BIN_DIR" 2>&1)
|
||||
|
||||
EXIT_CODE=$?
|
||||
echo "$OUTPUT"
|
||||
|
||||
if [ "$EXIT_CODE" -ne 0 ]; then
|
||||
echo "FAIL: Agent exited with code $EXIT_CODE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "pg_basebackup verified"; then
|
||||
echo "FAIL: Expected output to contain 'pg_basebackup verified'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "PostgreSQL connection verified"; then
|
||||
echo "FAIL: Expected output to contain 'PostgreSQL connection verified'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "pg_basebackup found via custom bin dir and DB connection verified"
|
||||
49
agent/e2e/scripts/test-pg-host-path.sh
Normal file
49
agent/e2e/scripts/test-pg-host-path.sh
Normal file
@@ -0,0 +1,49 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
|
||||
# Copy agent binary
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Verify pg_basebackup is in PATH
|
||||
if ! which pg_basebackup > /dev/null 2>&1; then
|
||||
echo "FAIL: pg_basebackup not found in PATH (test setup issue)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run start with --skip-update and pg-type=host
|
||||
echo "Running agent start (pg_basebackup in PATH)..."
|
||||
OUTPUT=$("$AGENT" start \
|
||||
--skip-update \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--wal-dir /tmp/wal \
|
||||
--pg-type host 2>&1)
|
||||
|
||||
EXIT_CODE=$?
|
||||
echo "$OUTPUT"
|
||||
|
||||
if [ "$EXIT_CODE" -ne 0 ]; then
|
||||
echo "FAIL: Agent exited with code $EXIT_CODE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "pg_basebackup verified"; then
|
||||
echo "FAIL: Expected output to contain 'pg_basebackup verified'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! echo "$OUTPUT" | grep -q "PostgreSQL connection verified"; then
|
||||
echo "FAIL: Expected output to contain 'PostgreSQL connection verified'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "pg_basebackup found in PATH and DB connection verified"
|
||||
51
agent/e2e/scripts/test-upgrade-skip.sh
Normal file
51
agent/e2e/scripts/test-upgrade-skip.sh
Normal file
@@ -0,0 +1,51 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
|
||||
# Set mock server to return v1.0.0 (same as agent)
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"version":"v1.0.0"}'
|
||||
|
||||
# Copy v1 binary to writable location
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Verify initial version
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v1.0.0" ]; then
|
||||
echo "FAIL: Expected initial version v1.0.0, got $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run start — agent should see version matches and skip upgrade
|
||||
echo "Running agent start (expecting upgrade skip)..."
|
||||
OUTPUT=$("$AGENT" start \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--wal-dir /tmp/wal \
|
||||
--pg-type host 2>&1) || true
|
||||
|
||||
echo "$OUTPUT"
|
||||
|
||||
# Verify output contains "up to date"
|
||||
if ! echo "$OUTPUT" | grep -qi "up to date"; then
|
||||
echo "FAIL: Expected output to contain 'up to date'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Verify binary is still v1
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v1.0.0" ]; then
|
||||
echo "FAIL: Expected version v1.0.0 (unchanged), got $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Upgrade correctly skipped, version still $VERSION"
|
||||
52
agent/e2e/scripts/test-upgrade-success.sh
Normal file
52
agent/e2e/scripts/test-upgrade-success.sh
Normal file
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
ARTIFACTS="/opt/agent/artifacts"
|
||||
AGENT="/tmp/test-agent"
|
||||
|
||||
# Ensure mock server returns v2.0.0
|
||||
curl -sf -X POST http://e2e-mock-server:4050/mock/set-version \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"version":"v2.0.0"}'
|
||||
|
||||
# Copy v1 binary to writable location
|
||||
cp "$ARTIFACTS/agent-v1" "$AGENT"
|
||||
chmod +x "$AGENT"
|
||||
|
||||
# Verify initial version
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v1.0.0" ]; then
|
||||
echo "FAIL: Expected initial version v1.0.0, got $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
echo "Initial version: $VERSION"
|
||||
|
||||
# Run start — agent will:
|
||||
# 1. Fetch version from mock (v2.0.0 != v1.0.0)
|
||||
# 2. Download v2 binary from mock
|
||||
# 3. Replace itself on disk
|
||||
# 4. Re-exec with same args
|
||||
# 5. Re-exec'd v2 fetches version (v2.0.0 == v2.0.0) → skips update
|
||||
# 6. Proceeds to start → verifies pg_basebackup + DB → exits 0 (stub)
|
||||
echo "Running agent start (expecting upgrade v1 -> v2)..."
|
||||
OUTPUT=$("$AGENT" start \
|
||||
--databasus-host http://e2e-mock-server:4050 \
|
||||
--db-id test-db-id \
|
||||
--token test-token \
|
||||
--pg-host e2e-postgres \
|
||||
--pg-port 5432 \
|
||||
--pg-user testuser \
|
||||
--pg-password testpassword \
|
||||
--wal-dir /tmp/wal \
|
||||
--pg-type host 2>&1) || true
|
||||
|
||||
echo "$OUTPUT"
|
||||
|
||||
# Verify binary on disk is now v2
|
||||
VERSION=$("$AGENT" version)
|
||||
if [ "$VERSION" != "v2.0.0" ]; then
|
||||
echo "FAIL: Expected upgraded version v2.0.0, got $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Binary upgraded successfully to $VERSION"
|
||||
19
agent/go.mod
Normal file
19
agent/go.mod
Normal file
@@ -0,0 +1,19 @@
|
||||
module databasus-agent
|
||||
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
github.com/jackc/pgx/v5 v5.8.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
golang.org/x/text v0.29.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
35
agent/go.sum
Normal file
35
agent/go.sum
Normal file
@@ -0,0 +1,35 @@
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
|
||||
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
|
||||
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
267
agent/internal/config/config.go
Normal file
267
agent/internal/config/config.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"databasus-agent/internal/logger"
|
||||
)
|
||||
|
||||
var log = logger.GetLogger()
|
||||
|
||||
const configFileName = "databasus.json"
|
||||
|
||||
type Config struct {
|
||||
DatabasusHost string `json:"databasusHost"`
|
||||
DbID string `json:"dbId"`
|
||||
Token string `json:"token"`
|
||||
PgHost string `json:"pgHost"`
|
||||
PgPort int `json:"pgPort"`
|
||||
PgUser string `json:"pgUser"`
|
||||
PgPassword string `json:"pgPassword"`
|
||||
PgType string `json:"pgType"`
|
||||
PgHostBinDir string `json:"pgHostBinDir"`
|
||||
PgDockerContainerName string `json:"pgDockerContainerName"`
|
||||
WalDir string `json:"walDir"`
|
||||
IsDeleteWalAfterUpload *bool `json:"deleteWalAfterUpload"`
|
||||
|
||||
flags parsedFlags
|
||||
}
|
||||
|
||||
// LoadFromJSONAndArgs reads databasus.json into the struct
|
||||
// and overrides JSON values with any explicitly provided CLI flags.
|
||||
func (c *Config) LoadFromJSONAndArgs(fs *flag.FlagSet, args []string) {
|
||||
c.loadFromJSON()
|
||||
c.applyDefaults()
|
||||
c.initSources()
|
||||
|
||||
c.flags.databasusHost = fs.String(
|
||||
"databasus-host",
|
||||
"",
|
||||
"Databasus server URL (e.g. http://your-server:4005)",
|
||||
)
|
||||
c.flags.dbID = fs.String("db-id", "", "Database ID")
|
||||
c.flags.token = fs.String("token", "", "Agent token")
|
||||
c.flags.pgHost = fs.String("pg-host", "", "PostgreSQL host")
|
||||
c.flags.pgPort = fs.Int("pg-port", 0, "PostgreSQL port")
|
||||
c.flags.pgUser = fs.String("pg-user", "", "PostgreSQL user")
|
||||
c.flags.pgPassword = fs.String("pg-password", "", "PostgreSQL password")
|
||||
c.flags.pgType = fs.String("pg-type", "", "PostgreSQL type: host or docker")
|
||||
c.flags.pgHostBinDir = fs.String("pg-host-bin-dir", "", "Path to PG bin directory (host mode)")
|
||||
c.flags.pgDockerContainerName = fs.String("pg-docker-container-name", "", "Docker container name (docker mode)")
|
||||
c.flags.walDir = fs.String("wal-dir", "", "Path to WAL queue directory")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
c.applyFlags()
|
||||
log.Info("========= Loading config ============")
|
||||
c.logConfigSources()
|
||||
log.Info("========= Config has been loaded ====")
|
||||
}
|
||||
|
||||
// SaveToJSON writes the current struct to databasus.json.
|
||||
func (c *Config) SaveToJSON() error {
|
||||
data, err := json.MarshalIndent(c, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(configFileName, data, 0o644)
|
||||
}
|
||||
|
||||
func (c *Config) loadFromJSON() {
|
||||
data, err := os.ReadFile(configFileName)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
log.Info("No databasus.json found, will create on save")
|
||||
return
|
||||
}
|
||||
|
||||
log.Warn("Failed to read databasus.json", "error", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, c); err != nil {
|
||||
log.Warn("Failed to parse databasus.json", "error", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("Configuration loaded from " + configFileName)
|
||||
}
|
||||
|
||||
func (c *Config) applyDefaults() {
|
||||
if c.PgPort == 0 {
|
||||
c.PgPort = 5432
|
||||
}
|
||||
|
||||
if c.PgType == "" {
|
||||
c.PgType = "host"
|
||||
}
|
||||
|
||||
if c.IsDeleteWalAfterUpload == nil {
|
||||
v := true
|
||||
c.IsDeleteWalAfterUpload = &v
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) initSources() {
|
||||
c.flags.sources = map[string]string{
|
||||
"databasus-host": "not configured",
|
||||
"db-id": "not configured",
|
||||
"token": "not configured",
|
||||
"pg-host": "not configured",
|
||||
"pg-port": "not configured",
|
||||
"pg-user": "not configured",
|
||||
"pg-password": "not configured",
|
||||
"pg-type": "not configured",
|
||||
"pg-host-bin-dir": "not configured",
|
||||
"pg-docker-container-name": "not configured",
|
||||
"wal-dir": "not configured",
|
||||
"delete-wal-after-upload": "not configured",
|
||||
}
|
||||
|
||||
if c.DatabasusHost != "" {
|
||||
c.flags.sources["databasus-host"] = configFileName
|
||||
}
|
||||
|
||||
if c.DbID != "" {
|
||||
c.flags.sources["db-id"] = configFileName
|
||||
}
|
||||
|
||||
if c.Token != "" {
|
||||
c.flags.sources["token"] = configFileName
|
||||
}
|
||||
|
||||
if c.PgHost != "" {
|
||||
c.flags.sources["pg-host"] = configFileName
|
||||
}
|
||||
|
||||
// PgPort always has a value after applyDefaults
|
||||
c.flags.sources["pg-port"] = configFileName
|
||||
|
||||
if c.PgUser != "" {
|
||||
c.flags.sources["pg-user"] = configFileName
|
||||
}
|
||||
|
||||
if c.PgPassword != "" {
|
||||
c.flags.sources["pg-password"] = configFileName
|
||||
}
|
||||
|
||||
// PgType always has a value after applyDefaults
|
||||
c.flags.sources["pg-type"] = configFileName
|
||||
|
||||
if c.PgHostBinDir != "" {
|
||||
c.flags.sources["pg-host-bin-dir"] = configFileName
|
||||
}
|
||||
|
||||
if c.PgDockerContainerName != "" {
|
||||
c.flags.sources["pg-docker-container-name"] = configFileName
|
||||
}
|
||||
|
||||
if c.WalDir != "" {
|
||||
c.flags.sources["wal-dir"] = configFileName
|
||||
}
|
||||
|
||||
// IsDeleteWalAfterUpload always has a value after applyDefaults
|
||||
c.flags.sources["delete-wal-after-upload"] = configFileName
|
||||
}
|
||||
|
||||
func (c *Config) applyFlags() {
|
||||
if c.flags.databasusHost != nil && *c.flags.databasusHost != "" {
|
||||
c.DatabasusHost = *c.flags.databasusHost
|
||||
c.flags.sources["databasus-host"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.dbID != nil && *c.flags.dbID != "" {
|
||||
c.DbID = *c.flags.dbID
|
||||
c.flags.sources["db-id"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.token != nil && *c.flags.token != "" {
|
||||
c.Token = *c.flags.token
|
||||
c.flags.sources["token"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgHost != nil && *c.flags.pgHost != "" {
|
||||
c.PgHost = *c.flags.pgHost
|
||||
c.flags.sources["pg-host"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgPort != nil && *c.flags.pgPort != 0 {
|
||||
c.PgPort = *c.flags.pgPort
|
||||
c.flags.sources["pg-port"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgUser != nil && *c.flags.pgUser != "" {
|
||||
c.PgUser = *c.flags.pgUser
|
||||
c.flags.sources["pg-user"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgPassword != nil && *c.flags.pgPassword != "" {
|
||||
c.PgPassword = *c.flags.pgPassword
|
||||
c.flags.sources["pg-password"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgType != nil && *c.flags.pgType != "" {
|
||||
c.PgType = *c.flags.pgType
|
||||
c.flags.sources["pg-type"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgHostBinDir != nil && *c.flags.pgHostBinDir != "" {
|
||||
c.PgHostBinDir = *c.flags.pgHostBinDir
|
||||
c.flags.sources["pg-host-bin-dir"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.pgDockerContainerName != nil && *c.flags.pgDockerContainerName != "" {
|
||||
c.PgDockerContainerName = *c.flags.pgDockerContainerName
|
||||
c.flags.sources["pg-docker-container-name"] = "command line args"
|
||||
}
|
||||
|
||||
if c.flags.walDir != nil && *c.flags.walDir != "" {
|
||||
c.WalDir = *c.flags.walDir
|
||||
c.flags.sources["wal-dir"] = "command line args"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) logConfigSources() {
|
||||
log.Info("databasus-host", "value", c.DatabasusHost, "source", c.flags.sources["databasus-host"])
|
||||
log.Info("db-id", "value", c.DbID, "source", c.flags.sources["db-id"])
|
||||
log.Info("token", "value", maskSensitive(c.Token), "source", c.flags.sources["token"])
|
||||
log.Info("pg-host", "value", c.PgHost, "source", c.flags.sources["pg-host"])
|
||||
log.Info("pg-port", "value", c.PgPort, "source", c.flags.sources["pg-port"])
|
||||
log.Info("pg-user", "value", c.PgUser, "source", c.flags.sources["pg-user"])
|
||||
log.Info("pg-password", "value", maskSensitive(c.PgPassword), "source", c.flags.sources["pg-password"])
|
||||
log.Info("pg-type", "value", c.PgType, "source", c.flags.sources["pg-type"])
|
||||
log.Info("pg-host-bin-dir", "value", c.PgHostBinDir, "source", c.flags.sources["pg-host-bin-dir"])
|
||||
log.Info(
|
||||
"pg-docker-container-name",
|
||||
"value",
|
||||
c.PgDockerContainerName,
|
||||
"source",
|
||||
c.flags.sources["pg-docker-container-name"],
|
||||
)
|
||||
log.Info("wal-dir", "value", c.WalDir, "source", c.flags.sources["wal-dir"])
|
||||
log.Info(
|
||||
"delete-wal-after-upload",
|
||||
"value",
|
||||
fmt.Sprintf("%v", *c.IsDeleteWalAfterUpload),
|
||||
"source",
|
||||
c.flags.sources["delete-wal-after-upload"],
|
||||
)
|
||||
}
|
||||
|
||||
func maskSensitive(value string) string {
|
||||
if value == "" {
|
||||
return "(not set)"
|
||||
}
|
||||
|
||||
visibleLen := max(len(value)/4, 1)
|
||||
|
||||
return value[:visibleLen] + "***"
|
||||
}
|
||||
301
agent/internal/config/config_test.go
Normal file
301
agent/internal/config/config_test.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_LoadFromJSONAndArgs_ValuesLoadedFromJSON(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
writeConfigJSON(t, dir, Config{
|
||||
DatabasusHost: "http://json-host:4005",
|
||||
DbID: "json-db-id",
|
||||
Token: "json-token",
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{})
|
||||
|
||||
assert.Equal(t, "http://json-host:4005", cfg.DatabasusHost)
|
||||
assert.Equal(t, "json-db-id", cfg.DbID)
|
||||
assert.Equal(t, "json-token", cfg.Token)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_ValuesLoadedFromArgs_WhenNoJSON(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--databasus-host", "http://arg-host:4005",
|
||||
"--db-id", "arg-db-id",
|
||||
"--token", "arg-token",
|
||||
})
|
||||
|
||||
assert.Equal(t, "http://arg-host:4005", cfg.DatabasusHost)
|
||||
assert.Equal(t, "arg-db-id", cfg.DbID)
|
||||
assert.Equal(t, "arg-token", cfg.Token)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_ArgsOverrideJSON(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
writeConfigJSON(t, dir, Config{
|
||||
DatabasusHost: "http://json-host:4005",
|
||||
DbID: "json-db-id",
|
||||
Token: "json-token",
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--databasus-host", "http://arg-host:9999",
|
||||
"--db-id", "arg-db-id-override",
|
||||
"--token", "arg-token-override",
|
||||
})
|
||||
|
||||
assert.Equal(t, "http://arg-host:9999", cfg.DatabasusHost)
|
||||
assert.Equal(t, "arg-db-id-override", cfg.DbID)
|
||||
assert.Equal(t, "arg-token-override", cfg.Token)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_PartialArgsOverrideJSON(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
writeConfigJSON(t, dir, Config{
|
||||
DatabasusHost: "http://json-host:4005",
|
||||
DbID: "json-db-id",
|
||||
Token: "json-token",
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--databasus-host", "http://arg-host-only:4005",
|
||||
})
|
||||
|
||||
assert.Equal(t, "http://arg-host-only:4005", cfg.DatabasusHost)
|
||||
assert.Equal(t, "json-db-id", cfg.DbID)
|
||||
assert.Equal(t, "json-token", cfg.Token)
|
||||
}
|
||||
|
||||
func Test_SaveToJSON_ConfigSavedCorrectly(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
deleteWal := true
|
||||
cfg := &Config{
|
||||
DatabasusHost: "http://save-host:4005",
|
||||
DbID: "save-db-id",
|
||||
Token: "save-token",
|
||||
IsDeleteWalAfterUpload: &deleteWal,
|
||||
}
|
||||
|
||||
err := cfg.SaveToJSON()
|
||||
require.NoError(t, err)
|
||||
|
||||
saved := readConfigJSON(t)
|
||||
|
||||
assert.Equal(t, "http://save-host:4005", saved.DatabasusHost)
|
||||
assert.Equal(t, "save-db-id", saved.DbID)
|
||||
assert.Equal(t, "save-token", saved.Token)
|
||||
}
|
||||
|
||||
func Test_SaveToJSON_AfterArgsOverrideJSON_SavedFileContainsMergedValues(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
writeConfigJSON(t, dir, Config{
|
||||
DatabasusHost: "http://json-host:4005",
|
||||
DbID: "json-db-id",
|
||||
Token: "json-token",
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--databasus-host", "http://override-host:9999",
|
||||
})
|
||||
|
||||
err := cfg.SaveToJSON()
|
||||
require.NoError(t, err)
|
||||
|
||||
saved := readConfigJSON(t)
|
||||
|
||||
assert.Equal(t, "http://override-host:9999", saved.DatabasusHost)
|
||||
assert.Equal(t, "json-db-id", saved.DbID)
|
||||
assert.Equal(t, "json-token", saved.Token)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_PgFieldsLoadedFromJSON(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
deleteWal := false
|
||||
writeConfigJSON(t, dir, Config{
|
||||
DatabasusHost: "http://json-host:4005",
|
||||
DbID: "json-db-id",
|
||||
Token: "json-token",
|
||||
PgHost: "pg-json-host",
|
||||
PgPort: 5433,
|
||||
PgUser: "pg-json-user",
|
||||
PgPassword: "pg-json-pass",
|
||||
PgType: "docker",
|
||||
PgHostBinDir: "/usr/bin",
|
||||
PgDockerContainerName: "pg-container",
|
||||
WalDir: "/opt/wal",
|
||||
IsDeleteWalAfterUpload: &deleteWal,
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{})
|
||||
|
||||
assert.Equal(t, "pg-json-host", cfg.PgHost)
|
||||
assert.Equal(t, 5433, cfg.PgPort)
|
||||
assert.Equal(t, "pg-json-user", cfg.PgUser)
|
||||
assert.Equal(t, "pg-json-pass", cfg.PgPassword)
|
||||
assert.Equal(t, "docker", cfg.PgType)
|
||||
assert.Equal(t, "/usr/bin", cfg.PgHostBinDir)
|
||||
assert.Equal(t, "pg-container", cfg.PgDockerContainerName)
|
||||
assert.Equal(t, "/opt/wal", cfg.WalDir)
|
||||
assert.Equal(t, false, *cfg.IsDeleteWalAfterUpload)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_PgFieldsLoadedFromArgs(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--pg-host", "arg-pg-host",
|
||||
"--pg-port", "5433",
|
||||
"--pg-user", "arg-pg-user",
|
||||
"--pg-password", "arg-pg-pass",
|
||||
"--pg-type", "docker",
|
||||
"--pg-host-bin-dir", "/custom/bin",
|
||||
"--pg-docker-container-name", "my-pg",
|
||||
"--wal-dir", "/var/wal",
|
||||
})
|
||||
|
||||
assert.Equal(t, "arg-pg-host", cfg.PgHost)
|
||||
assert.Equal(t, 5433, cfg.PgPort)
|
||||
assert.Equal(t, "arg-pg-user", cfg.PgUser)
|
||||
assert.Equal(t, "arg-pg-pass", cfg.PgPassword)
|
||||
assert.Equal(t, "docker", cfg.PgType)
|
||||
assert.Equal(t, "/custom/bin", cfg.PgHostBinDir)
|
||||
assert.Equal(t, "my-pg", cfg.PgDockerContainerName)
|
||||
assert.Equal(t, "/var/wal", cfg.WalDir)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_PgArgsOverrideJSON(t *testing.T) {
|
||||
dir := setupTempDir(t)
|
||||
writeConfigJSON(t, dir, Config{
|
||||
PgHost: "json-host",
|
||||
PgPort: 5432,
|
||||
PgUser: "json-user",
|
||||
PgType: "host",
|
||||
WalDir: "/json/wal",
|
||||
})
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{
|
||||
"--pg-host", "arg-host",
|
||||
"--pg-port", "5433",
|
||||
"--pg-user", "arg-user",
|
||||
"--pg-type", "docker",
|
||||
"--pg-docker-container-name", "my-container",
|
||||
"--wal-dir", "/arg/wal",
|
||||
})
|
||||
|
||||
assert.Equal(t, "arg-host", cfg.PgHost)
|
||||
assert.Equal(t, 5433, cfg.PgPort)
|
||||
assert.Equal(t, "arg-user", cfg.PgUser)
|
||||
assert.Equal(t, "docker", cfg.PgType)
|
||||
assert.Equal(t, "my-container", cfg.PgDockerContainerName)
|
||||
assert.Equal(t, "/arg/wal", cfg.WalDir)
|
||||
}
|
||||
|
||||
func Test_LoadFromJSONAndArgs_DefaultsApplied_WhenNoJSONAndNoArgs(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
cfg := &Config{}
|
||||
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
cfg.LoadFromJSONAndArgs(fs, []string{})
|
||||
|
||||
assert.Equal(t, 5432, cfg.PgPort)
|
||||
assert.Equal(t, "host", cfg.PgType)
|
||||
require.NotNil(t, cfg.IsDeleteWalAfterUpload)
|
||||
assert.Equal(t, true, *cfg.IsDeleteWalAfterUpload)
|
||||
}
|
||||
|
||||
func Test_SaveToJSON_PgFieldsSavedCorrectly(t *testing.T) {
|
||||
setupTempDir(t)
|
||||
|
||||
deleteWal := false
|
||||
cfg := &Config{
|
||||
DatabasusHost: "http://host:4005",
|
||||
DbID: "db-id",
|
||||
Token: "token",
|
||||
PgHost: "pg-host",
|
||||
PgPort: 5433,
|
||||
PgUser: "pg-user",
|
||||
PgPassword: "pg-pass",
|
||||
PgType: "docker",
|
||||
PgHostBinDir: "/usr/bin",
|
||||
PgDockerContainerName: "pg-container",
|
||||
WalDir: "/opt/wal",
|
||||
IsDeleteWalAfterUpload: &deleteWal,
|
||||
}
|
||||
|
||||
err := cfg.SaveToJSON()
|
||||
require.NoError(t, err)
|
||||
|
||||
saved := readConfigJSON(t)
|
||||
|
||||
assert.Equal(t, "pg-host", saved.PgHost)
|
||||
assert.Equal(t, 5433, saved.PgPort)
|
||||
assert.Equal(t, "pg-user", saved.PgUser)
|
||||
assert.Equal(t, "pg-pass", saved.PgPassword)
|
||||
assert.Equal(t, "docker", saved.PgType)
|
||||
assert.Equal(t, "/usr/bin", saved.PgHostBinDir)
|
||||
assert.Equal(t, "pg-container", saved.PgDockerContainerName)
|
||||
assert.Equal(t, "/opt/wal", saved.WalDir)
|
||||
require.NotNil(t, saved.IsDeleteWalAfterUpload)
|
||||
assert.Equal(t, false, *saved.IsDeleteWalAfterUpload)
|
||||
}
|
||||
|
||||
func setupTempDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
origDir, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
|
||||
dir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(dir))
|
||||
|
||||
t.Cleanup(func() { os.Chdir(origDir) })
|
||||
|
||||
return dir
|
||||
}
|
||||
|
||||
func writeConfigJSON(t *testing.T, dir string, cfg Config) {
|
||||
t.Helper()
|
||||
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, os.WriteFile(dir+"/"+configFileName, data, 0o644))
|
||||
}
|
||||
|
||||
func readConfigJSON(t *testing.T) Config {
|
||||
t.Helper()
|
||||
|
||||
data, err := os.ReadFile(configFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
var cfg Config
|
||||
require.NoError(t, json.Unmarshal(data, &cfg))
|
||||
|
||||
return cfg
|
||||
}
|
||||
17
agent/internal/config/dto.go
Normal file
17
agent/internal/config/dto.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package config
|
||||
|
||||
type parsedFlags struct {
|
||||
databasusHost *string
|
||||
dbID *string
|
||||
token *string
|
||||
pgHost *string
|
||||
pgPort *int
|
||||
pgUser *string
|
||||
pgPassword *string
|
||||
pgType *string
|
||||
pgHostBinDir *string
|
||||
pgDockerContainerName *string
|
||||
walDir *string
|
||||
|
||||
sources map[string]string
|
||||
}
|
||||
179
agent/internal/features/start/start.go
Normal file
179
agent/internal/features/start/start.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package start
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"databasus-agent/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
pgBasebackupVerifyTimeout = 10 * time.Second
|
||||
dbVerifyTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
func Run(cfg *config.Config, log *slog.Logger) error {
|
||||
if err := validateConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := verifyPgBasebackup(cfg, log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := verifyDatabase(cfg, log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("start: stub — not yet implemented",
|
||||
"dbId", cfg.DbID,
|
||||
"hasToken", cfg.Token != "",
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateConfig(cfg *config.Config) error {
|
||||
if cfg.DatabasusHost == "" {
|
||||
return errors.New("argument databasus-host is required")
|
||||
}
|
||||
|
||||
if cfg.DbID == "" {
|
||||
return errors.New("argument db-id is required")
|
||||
}
|
||||
|
||||
if cfg.Token == "" {
|
||||
return errors.New("argument token is required")
|
||||
}
|
||||
|
||||
if cfg.PgHost == "" {
|
||||
return errors.New("argument pg-host is required")
|
||||
}
|
||||
|
||||
if cfg.PgPort <= 0 {
|
||||
return errors.New("argument pg-port must be a positive number")
|
||||
}
|
||||
|
||||
if cfg.PgUser == "" {
|
||||
return errors.New("argument pg-user is required")
|
||||
}
|
||||
|
||||
if cfg.PgType != "host" && cfg.PgType != "docker" {
|
||||
return fmt.Errorf("argument pg-type must be 'host' or 'docker', got '%s'", cfg.PgType)
|
||||
}
|
||||
|
||||
if cfg.WalDir == "" {
|
||||
return errors.New("argument wal-dir is required")
|
||||
}
|
||||
|
||||
if cfg.PgType == "docker" && cfg.PgDockerContainerName == "" {
|
||||
return errors.New("argument pg-docker-container-name is required when pg-type is 'docker'")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyPgBasebackup(cfg *config.Config, log *slog.Logger) error {
|
||||
switch cfg.PgType {
|
||||
case "host":
|
||||
return verifyPgBasebackupHost(cfg, log)
|
||||
case "docker":
|
||||
return verifyPgBasebackupDocker(cfg, log)
|
||||
default:
|
||||
return fmt.Errorf("unexpected pg-type: %s", cfg.PgType)
|
||||
}
|
||||
}
|
||||
|
||||
func verifyPgBasebackupHost(cfg *config.Config, log *slog.Logger) error {
|
||||
binary := "pg_basebackup"
|
||||
if cfg.PgHostBinDir != "" {
|
||||
binary = filepath.Join(cfg.PgHostBinDir, "pg_basebackup")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgBasebackupVerifyTimeout)
|
||||
defer cancel()
|
||||
|
||||
output, err := exec.CommandContext(ctx, binary, "--version").CombinedOutput()
|
||||
if err != nil {
|
||||
if cfg.PgHostBinDir != "" {
|
||||
return fmt.Errorf(
|
||||
"pg_basebackup not found at '%s': %w. Verify pg-host-bin-dir is correct",
|
||||
binary, err,
|
||||
)
|
||||
}
|
||||
|
||||
return fmt.Errorf(
|
||||
"pg_basebackup not found in PATH: %w. Install PostgreSQL client tools or set pg-host-bin-dir",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
log.Info("pg_basebackup verified", "version", strings.TrimSpace(string(output)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyPgBasebackupDocker(cfg *config.Config, log *slog.Logger) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgBasebackupVerifyTimeout)
|
||||
defer cancel()
|
||||
|
||||
output, err := exec.CommandContext(ctx,
|
||||
"docker", "exec", cfg.PgDockerContainerName,
|
||||
"pg_basebackup", "--version",
|
||||
).CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"pg_basebackup not available in container '%s': %w. "+
|
||||
"Check that the container is running and pg_basebackup is installed inside it",
|
||||
cfg.PgDockerContainerName, err,
|
||||
)
|
||||
}
|
||||
|
||||
log.Info("pg_basebackup verified (docker)",
|
||||
"container", cfg.PgDockerContainerName,
|
||||
"version", strings.TrimSpace(string(output)),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyDatabase(cfg *config.Config, log *slog.Logger) error {
|
||||
connStr := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=postgres sslmode=disable",
|
||||
cfg.PgHost, cfg.PgPort, cfg.PgUser, cfg.PgPassword,
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dbVerifyTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgx.Connect(ctx, connStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to connect to PostgreSQL at %s:%d as user '%s': %w",
|
||||
cfg.PgHost, cfg.PgPort, cfg.PgUser, err,
|
||||
)
|
||||
}
|
||||
defer func() { _ = conn.Close(ctx) }()
|
||||
|
||||
if err := conn.Ping(ctx); err != nil {
|
||||
return fmt.Errorf("PostgreSQL ping failed at %s:%d: %w",
|
||||
cfg.PgHost, cfg.PgPort, err,
|
||||
)
|
||||
}
|
||||
|
||||
log.Info("PostgreSQL connection verified",
|
||||
"host", cfg.PgHost,
|
||||
"port", cfg.PgPort,
|
||||
"user", cfg.PgUser,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
5
agent/internal/features/upgrade/dto.go
Normal file
5
agent/internal/features/upgrade/dto.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package upgrade
|
||||
|
||||
type versionResponse struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
154
agent/internal/features/upgrade/upgrader.go
Normal file
154
agent/internal/features/upgrade/upgrader.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package upgrade
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
func CheckAndUpdate(databasusHost, currentVersion string, isDev bool, log *slog.Logger) error {
|
||||
if isDev {
|
||||
log.Info("Skipping update check (development mode)")
|
||||
return nil
|
||||
}
|
||||
|
||||
serverVersion, err := fetchServerVersion(databasusHost, log)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"unable to check version, please verify Databasus server is available at %s: %w",
|
||||
databasusHost,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
if serverVersion == currentVersion {
|
||||
log.Info("Agent version is up to date", "version", currentVersion)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info("Updating agent...", "current", currentVersion, "target", serverVersion)
|
||||
|
||||
selfPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to determine executable path: %w", err)
|
||||
}
|
||||
|
||||
tempPath := selfPath + ".update"
|
||||
|
||||
defer func() {
|
||||
_ = os.Remove(tempPath)
|
||||
}()
|
||||
|
||||
if err := downloadBinary(databasusHost, tempPath); err != nil {
|
||||
return fmt.Errorf("failed to download update: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tempPath, 0o755); err != nil {
|
||||
return fmt.Errorf("failed to set permissions on update: %w", err)
|
||||
}
|
||||
|
||||
if err := verifyBinary(tempPath, serverVersion); err != nil {
|
||||
return fmt.Errorf("update verification failed: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tempPath, selfPath); err != nil {
|
||||
return fmt.Errorf("failed to replace binary (try --skip-update if this persists): %w", err)
|
||||
}
|
||||
|
||||
log.Info("Update complete, re-executing...")
|
||||
|
||||
return syscall.Exec(selfPath, os.Args, os.Environ())
|
||||
}
|
||||
|
||||
func fetchServerVersion(host string, log *slog.Logger) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, host+"/api/v1/system/version", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Warn("Could not reach server for update check, continuing", "error", err)
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Warn(
|
||||
"Server returned non-OK status for version check, continuing",
|
||||
"status",
|
||||
resp.StatusCode,
|
||||
)
|
||||
return "", fmt.Errorf("status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var ver versionResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&ver); err != nil {
|
||||
log.Warn("Failed to parse server version response, continuing", "error", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
return ver.Version, nil
|
||||
}
|
||||
|
||||
func downloadBinary(host, destPath string) error {
|
||||
url := fmt.Sprintf("%s/api/v1/system/agent?arch=%s", host, runtime.GOARCH)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("server returned %d for agent download", resp.StatusCode)
|
||||
}
|
||||
|
||||
f, err := os.Create(destPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
_, err = io.Copy(f, resp.Body)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func verifyBinary(binaryPath, expectedVersion string) error {
|
||||
cmd := exec.CommandContext(context.Background(), binaryPath, "version")
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("binary failed to execute: %w", err)
|
||||
}
|
||||
|
||||
got := strings.TrimSpace(string(output))
|
||||
if got != expectedVersion {
|
||||
return fmt.Errorf("version mismatch: expected %q, got %q", expectedVersion, got)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
45
agent/internal/logger/logger.go
Normal file
45
agent/internal/logger/logger.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
loggerInstance *slog.Logger
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func Init(isDebug bool) {
|
||||
level := slog.LevelInfo
|
||||
if isDebug {
|
||||
level = slog.LevelDebug
|
||||
}
|
||||
|
||||
once.Do(func() {
|
||||
loggerInstance = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
|
||||
if a.Key == slog.TimeKey {
|
||||
a.Value = slog.StringValue(time.Now().Format("2006/01/02 15:04:05"))
|
||||
}
|
||||
if a.Key == slog.LevelKey {
|
||||
return slog.Attr{}
|
||||
}
|
||||
|
||||
return a
|
||||
},
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
// GetLogger returns a singleton slog.Logger that logs to the console
|
||||
func GetLogger() *slog.Logger {
|
||||
if loggerInstance == nil {
|
||||
Init(false)
|
||||
}
|
||||
|
||||
return loggerInstance
|
||||
}
|
||||
17
assets/tools/README.md
Normal file
17
assets/tools/README.md
Normal file
@@ -0,0 +1,17 @@
|
||||
We keep binaries here to speed up CI \ CD tasks and building.
|
||||
|
||||
Docker image needs:
|
||||
- PostgreSQL client tools (versions 12-18)
|
||||
- MySQL client tools (versions 5.7, 8.0, 8.4, 9)
|
||||
- MariaDB client tools (versions 10.6, 12.1)
|
||||
- MongoDB Database Tools (latest)
|
||||
|
||||
For the most of tools, we need a couple of binaries for each version. However, if we download them on each run, it will download a couple of GBs each time.
|
||||
|
||||
So, for speed up we keep only required executables (like pg_dump, mysqldump, mariadb-dump, mongodump, etc.).
|
||||
|
||||
It takes:
|
||||
- ~ 100MB for ARM
|
||||
- ~ 100MB for x64
|
||||
|
||||
Instead of GBs. See Dockefile for usage details.
|
||||
@@ -1,152 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
Always place private methods to the bottom of file
|
||||
|
||||
**This rule applies to ALL Go files including tests, services, controllers, repositories, etc.**
|
||||
|
||||
In Go, exported (public) functions/methods start with uppercase letters, while unexported (private) ones start with lowercase letters.
|
||||
|
||||
## Structure Order:
|
||||
|
||||
1. Type definitions and constants
|
||||
2. Public methods/functions (uppercase)
|
||||
3. Private methods/functions (lowercase)
|
||||
|
||||
## Examples:
|
||||
|
||||
### Service with methods:
|
||||
|
||||
```go
|
||||
type UserService struct {
|
||||
repository *UserRepository
|
||||
}
|
||||
|
||||
// Public methods first
|
||||
func (s *UserService) CreateUser(user *User) error {
|
||||
if err := s.validateUser(user); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.repository.Save(user)
|
||||
}
|
||||
|
||||
func (s *UserService) GetUser(id uuid.UUID) (*User, error) {
|
||||
return s.repository.FindByID(id)
|
||||
}
|
||||
|
||||
// Private methods at the bottom
|
||||
func (s *UserService) validateUser(user *User) error {
|
||||
if user.Name == "" {
|
||||
return errors.New("name is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
### Package-level functions:
|
||||
|
||||
```go
|
||||
package utils
|
||||
|
||||
// Public functions first
|
||||
func ProcessData(data []byte) (Result, error) {
|
||||
cleaned := sanitizeInput(data)
|
||||
return parseData(cleaned)
|
||||
}
|
||||
|
||||
func ValidateInput(input string) bool {
|
||||
return isValidFormat(input) && checkLength(input)
|
||||
}
|
||||
|
||||
// Private functions at the bottom
|
||||
func sanitizeInput(data []byte) []byte {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func parseData(data []byte) (Result, error) {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func isValidFormat(input string) bool {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func checkLength(input string) bool {
|
||||
// implementation
|
||||
}
|
||||
```
|
||||
|
||||
### Test files:
|
||||
|
||||
```go
|
||||
package user_test
|
||||
|
||||
// Public test functions first
|
||||
func Test_CreateUser_ValidInput_UserCreated(t *testing.T) {
|
||||
user := createTestUser()
|
||||
result, err := service.CreateUser(user)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func Test_GetUser_ExistingUser_ReturnsUser(t *testing.T) {
|
||||
user := createTestUser()
|
||||
// test implementation
|
||||
}
|
||||
|
||||
// Private helper functions at the bottom
|
||||
func createTestUser() *User {
|
||||
return &User{
|
||||
Name: "Test User",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestDatabase() *Database {
|
||||
// setup implementation
|
||||
}
|
||||
```
|
||||
|
||||
### Controller example:
|
||||
|
||||
```go
|
||||
type ProjectController struct {
|
||||
service *ProjectService
|
||||
}
|
||||
|
||||
// Public HTTP handlers first
|
||||
func (c *ProjectController) CreateProject(ctx *gin.Context) {
|
||||
var request CreateProjectRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
c.handleError(ctx, err)
|
||||
return
|
||||
}
|
||||
// handler logic
|
||||
}
|
||||
|
||||
func (c *ProjectController) GetProject(ctx *gin.Context) {
|
||||
projectID := c.extractProjectID(ctx)
|
||||
// handler logic
|
||||
}
|
||||
|
||||
// Private helper methods at the bottom
|
||||
func (c *ProjectController) handleError(ctx *gin.Context, err error) {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
}
|
||||
|
||||
func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
|
||||
return uuid.MustParse(ctx.Param("projectId"))
|
||||
}
|
||||
```
|
||||
|
||||
## Key Points:
|
||||
|
||||
- **Exported/Public** = starts with uppercase letter (CreateUser, GetProject)
|
||||
- **Unexported/Private** = starts with lowercase letter (validateUser, handleError)
|
||||
- This improves code readability by showing the public API first
|
||||
- Private helpers are implementation details, so they go at the bottom
|
||||
- Apply this rule consistently across ALL Go files in the project
|
||||
@@ -1,45 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
## Comment Guidelines
|
||||
|
||||
1. **No obvious comments** - Don't state what the code already clearly shows
|
||||
2. **Functions and variables should have meaningful names** - Code should be self-documenting
|
||||
3. **Comments for unclear code only** - Only add comments when code logic isn't immediately clear
|
||||
|
||||
## Key Principles:
|
||||
|
||||
- **Code should tell a story** - Use descriptive variable and function names
|
||||
- **Comments explain WHY, not WHAT** - The code shows what happens, comments explain business logic or complex decisions
|
||||
- **Prefer refactoring over commenting** - If code needs explaining, consider making it clearer instead
|
||||
- **API documentation is required** - Swagger comments for all HTTP endpoints are mandatory
|
||||
- **Complex algorithms deserve comments** - Mathematical formulas, business rules, or non-obvious optimizations
|
||||
|
||||
Example of useless comment:
|
||||
|
||||
1.
|
||||
|
||||
```sql
|
||||
// Create projects table
|
||||
CREATE TABLE projects (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
name TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
```
|
||||
|
||||
2.
|
||||
|
||||
```go
|
||||
// Create test project
|
||||
project := CreateTestProject(projectName, user, router)
|
||||
```
|
||||
|
||||
3.
|
||||
|
||||
```go
|
||||
// CreateValidLogItems creates valid log items for testing
|
||||
func CreateValidLogItems(count int, uniqueID string) []logs_receiving.LogItemRequestDTO {
|
||||
```
|
||||
@@ -1,133 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
1. When we write controller:
|
||||
|
||||
- we combine all routes to single controller
|
||||
- names them as .WhatWeDo (not "handlers") concept
|
||||
|
||||
2. We use gin and \*gin.Context for all routes.
|
||||
Example:
|
||||
|
||||
func (c *TasksController) GetAvailableTasks(ctx *gin.Context) ...
|
||||
|
||||
3. We document all routes with Swagger in the following format:
|
||||
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
auditLogService \*AuditLogService
|
||||
}
|
||||
|
||||
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// All audit log endpoints require authentication (handled in main.go)
|
||||
auditRoutes := router.Group("/audit-logs")
|
||||
|
||||
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
|
||||
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
|
||||
|
||||
}
|
||||
|
||||
// GetGlobalAuditLogs
|
||||
// @Summary Get global audit logs (ADMIN only)
|
||||
// @Description Retrieve all audit logs across the system
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/global [get]
|
||||
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(\*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "only administrators can view global audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
|
||||
}
|
||||
|
||||
// GetUserAuditLogs
|
||||
// @Summary Get user audit logs
|
||||
// @Description Retrieve audit logs for a specific user
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param userId path string true "User ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/users/{userId} [get]
|
||||
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(\*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
userIDStr := ctx.Param("userId")
|
||||
targetUserID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
|
||||
}
|
||||
@@ -1,671 +0,0 @@
|
||||
---
|
||||
alwaysApply: false
|
||||
---
|
||||
|
||||
This is example of CRUD:
|
||||
|
||||
------ backend/internal/features/audit_logs/controller.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
auditLogService *AuditLogService
|
||||
}
|
||||
|
||||
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// All audit log endpoints require authentication (handled in main.go)
|
||||
auditRoutes := router.Group("/audit-logs")
|
||||
|
||||
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
|
||||
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
|
||||
}
|
||||
|
||||
// GetGlobalAuditLogs
|
||||
// @Summary Get global audit logs (ADMIN only)
|
||||
// @Description Retrieve all audit logs across the system
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/global [get]
|
||||
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "only administrators can view global audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetUserAuditLogs
|
||||
// @Summary Get user audit logs
|
||||
// @Description Retrieve audit logs for a specific user
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param userId path string true "User ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/users/{userId} [get]
|
||||
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
userIDStr := ctx.Param("userId")
|
||||
targetUserID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/controller_test.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_GetGlobalAuditLogs_AdminSucceedsAndMemberGetsForbidden(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
memberUser := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
projectID := uuid.New()
|
||||
|
||||
// Create test logs
|
||||
createAuditLog(service, "Test log with user", &adminUser.UserID, nil)
|
||||
createAuditLog(service, "Test log with project", nil, &projectID)
|
||||
createAuditLog(service, "Test log standalone", nil, nil)
|
||||
|
||||
// Test ADMIN can access global logs
|
||||
var response GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
"/api/v1/audit-logs/global?limit=10", "Bearer "+adminUser.Token, http.StatusOK, &response)
|
||||
|
||||
assert.GreaterOrEqual(t, len(response.AuditLogs), 3)
|
||||
assert.GreaterOrEqual(t, response.Total, int64(3))
|
||||
|
||||
messages := extractMessages(response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test log with user")
|
||||
assert.Contains(t, messages, "Test log with project")
|
||||
assert.Contains(t, messages, "Test log standalone")
|
||||
|
||||
// Test MEMBER cannot access global logs
|
||||
resp := test_utils.MakeGetRequest(t, router, "/api/v1/audit-logs/global",
|
||||
"Bearer "+memberUser.Token, http.StatusForbidden)
|
||||
assert.Contains(t, string(resp.Body), "only administrators can view global audit logs")
|
||||
}
|
||||
|
||||
func Test_GetUserAuditLogs_PermissionsEnforcedCorrectly(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
projectID := uuid.New()
|
||||
|
||||
// Create test logs for different users
|
||||
createAuditLog(service, "Test log user1 first", &user1.UserID, nil)
|
||||
createAuditLog(service, "Test log user1 second", &user1.UserID, &projectID)
|
||||
createAuditLog(service, "Test log user2 first", &user2.UserID, nil)
|
||||
createAuditLog(service, "Test log user2 second", &user2.UserID, &projectID)
|
||||
createAuditLog(service, "Test project log", nil, &projectID)
|
||||
|
||||
// Test ADMIN can view any user's logs
|
||||
var user1Response GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=10", user1.UserID.String()),
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &user1Response)
|
||||
|
||||
assert.Equal(t, 2, len(user1Response.AuditLogs))
|
||||
messages := extractMessages(user1Response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test log user1 first")
|
||||
assert.Contains(t, messages, "Test log user1 second")
|
||||
|
||||
// Test user can view own logs
|
||||
var ownLogsResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s", user2.UserID.String()),
|
||||
"Bearer "+user2.Token, http.StatusOK, &ownLogsResponse)
|
||||
assert.Equal(t, 2, len(ownLogsResponse.AuditLogs))
|
||||
|
||||
// Test user cannot view other user's logs
|
||||
resp := test_utils.MakeGetRequest(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s", user1.UserID.String()),
|
||||
"Bearer "+user2.Token, http.StatusForbidden)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
func Test_FilterAuditLogsByTime_ReturnsOnlyLogsBeforeDate(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
db := storage.GetDb()
|
||||
baseTime := time.Now().UTC()
|
||||
|
||||
// Create logs with different timestamps
|
||||
createTimedLog(db, &adminUser.UserID, "Test old log", baseTime.Add(-2*time.Hour))
|
||||
createTimedLog(db, &adminUser.UserID, "Test recent log", baseTime.Add(-30*time.Minute))
|
||||
createAuditLog(service, "Test current log", &adminUser.UserID, nil)
|
||||
|
||||
// Test filtering - get logs before 1 hour ago
|
||||
beforeTime := baseTime.Add(-1 * time.Hour)
|
||||
var filteredResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/global?beforeDate=%s", beforeTime.Format(time.RFC3339)),
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &filteredResponse)
|
||||
|
||||
// Verify only old log is returned
|
||||
messages := extractMessages(filteredResponse.AuditLogs)
|
||||
assert.Contains(t, messages, "Test old log")
|
||||
assert.NotContains(t, messages, "Test recent log")
|
||||
assert.NotContains(t, messages, "Test current log")
|
||||
|
||||
// Test without filter - should get all logs
|
||||
var allResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router, "/api/v1/audit-logs/global",
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &allResponse)
|
||||
assert.GreaterOrEqual(t, len(allResponse.AuditLogs), 3)
|
||||
}
|
||||
|
||||
func createRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
SetupDependencies()
|
||||
|
||||
v1 := router.Group("/api/v1")
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
GetAuditLogController().RegisterRoutes(protected.(*gin.RouterGroup))
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/di.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var auditLogRepository = &AuditLogRepository{}
|
||||
var auditLogService = &AuditLogService{
|
||||
auditLogRepository: auditLogRepository,
|
||||
logger: logger.GetLogger(),
|
||||
}
|
||||
var auditLogController = &AuditLogController{
|
||||
auditLogService: auditLogService,
|
||||
}
|
||||
|
||||
func GetAuditLogService() *AuditLogService {
|
||||
return auditLogService
|
||||
}
|
||||
|
||||
func GetAuditLogController() *AuditLogController {
|
||||
return auditLogController
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/dto.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import "time"
|
||||
|
||||
type GetAuditLogsRequest struct {
|
||||
Limit int `form:"limit" json:"limit"`
|
||||
Offset int `form:"offset" json:"offset"`
|
||||
BeforeDate *time.Time `form:"beforeDate" json:"beforeDate"`
|
||||
}
|
||||
|
||||
type GetAuditLogsResponse struct {
|
||||
AuditLogs []*AuditLog `json:"auditLogs"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/models.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLog struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id"`
|
||||
UserID *uuid.UUID `json:"userId" gorm:"column:user_id"`
|
||||
ProjectID *uuid.UUID `json:"projectId" gorm:"column:project_id"`
|
||||
Message string `json:"message" gorm:"column:message"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
||||
}
|
||||
|
||||
func (AuditLog) TableName() string {
|
||||
return "audit_logs"
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/repository.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogRepository struct{}
|
||||
|
||||
func (r *AuditLogRepository) Create(auditLog *AuditLog) error {
|
||||
if auditLog.ID == uuid.Nil {
|
||||
auditLog.ID = uuid.New()
|
||||
}
|
||||
|
||||
return storage.GetDb().Create(auditLog).Error
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetGlobal(limit, offset int, beforeDate *time.Time) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetByUser(
|
||||
userID uuid.UUID,
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().
|
||||
Where("user_id = ?", userID).
|
||||
Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetByProject(
|
||||
projectID uuid.UUID,
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().
|
||||
Where("project_id = ?", projectID).
|
||||
Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
|
||||
var count int64
|
||||
query := storage.GetDb().Model(&AuditLog{})
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/service.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogService struct {
|
||||
auditLogRepository *AuditLogRepository
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *AuditLogService) WriteAuditLog(
|
||||
message string,
|
||||
userID *uuid.UUID,
|
||||
projectID *uuid.UUID,
|
||||
) {
|
||||
auditLog := &AuditLog{
|
||||
UserID: userID,
|
||||
ProjectID: projectID,
|
||||
Message: message,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
err := s.auditLogRepository.Create(auditLog)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to create audit log", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuditLogService) CreateAuditLog(auditLog *AuditLog) error {
|
||||
return s.auditLogRepository.Create(auditLog)
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetGlobalAuditLogs(
|
||||
user *user_models.User,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
if user.Role != user_enums.UserRoleAdmin {
|
||||
return nil, errors.New("only administrators can view global audit logs")
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetGlobal(limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
total, err := s.auditLogRepository.CountGlobal(request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: total,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetUserAuditLogs(
|
||||
targetUserID uuid.UUID,
|
||||
user *user_models.User,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
// Users can view their own logs, ADMIN can view any user's logs
|
||||
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
|
||||
return nil, errors.New("insufficient permissions to view user audit logs")
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetByUser(targetUserID, limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: int64(len(auditLogs)),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetProjectAuditLogs(
|
||||
projectID uuid.UUID,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetByProject(projectID, limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: int64(len(auditLogs)),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/service_test.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func Test_AuditLogs_ProjectSpecificLogs(t *testing.T) {
|
||||
service := GetAuditLogService()
|
||||
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
project1ID, project2ID := uuid.New(), uuid.New()
|
||||
|
||||
// Create test logs for projects
|
||||
createAuditLog(service, "Test project1 log first", &user1.UserID, &project1ID)
|
||||
createAuditLog(service, "Test project1 log second", &user2.UserID, &project1ID)
|
||||
createAuditLog(service, "Test project2 log first", &user1.UserID, &project2ID)
|
||||
createAuditLog(service, "Test project2 log second", &user2.UserID, &project2ID)
|
||||
createAuditLog(service, "Test no project log", &user1.UserID, nil)
|
||||
|
||||
request := &GetAuditLogsRequest{Limit: 10, Offset: 0}
|
||||
|
||||
// Test project 1 logs
|
||||
project1Response, err := service.GetProjectAuditLogs(project1ID, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(project1Response.AuditLogs))
|
||||
|
||||
messages := extractMessages(project1Response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test project1 log first")
|
||||
assert.Contains(t, messages, "Test project1 log second")
|
||||
for _, log := range project1Response.AuditLogs {
|
||||
assert.Equal(t, &project1ID, log.ProjectID)
|
||||
}
|
||||
|
||||
// Test project 2 logs
|
||||
project2Response, err := service.GetProjectAuditLogs(project2ID, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(project2Response.AuditLogs))
|
||||
|
||||
messages2 := extractMessages(project2Response.AuditLogs)
|
||||
assert.Contains(t, messages2, "Test project2 log first")
|
||||
assert.Contains(t, messages2, "Test project2 log second")
|
||||
|
||||
// Test pagination
|
||||
limitedResponse, err := service.GetProjectAuditLogs(project1ID,
|
||||
&GetAuditLogsRequest{Limit: 1, Offset: 0})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(limitedResponse.AuditLogs))
|
||||
assert.Equal(t, 1, limitedResponse.Limit)
|
||||
|
||||
// Test beforeDate filter
|
||||
beforeTime := time.Now().UTC().Add(-1 * time.Minute)
|
||||
filteredResponse, err := service.GetProjectAuditLogs(project1ID,
|
||||
&GetAuditLogsRequest{Limit: 10, BeforeDate: &beforeTime})
|
||||
assert.NoError(t, err)
|
||||
for _, log := range filteredResponse.AuditLogs {
|
||||
assert.True(t, log.CreatedAt.Before(beforeTime))
|
||||
}
|
||||
}
|
||||
|
||||
func createAuditLog(service *AuditLogService, message string, userID, projectID *uuid.UUID) {
|
||||
service.WriteAuditLog(message, userID, projectID)
|
||||
}
|
||||
|
||||
func extractMessages(logs []*AuditLog) []string {
|
||||
messages := make([]string, len(logs))
|
||||
for i, log := range logs {
|
||||
messages[i] = log.Message
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt time.Time) {
|
||||
log := &AuditLog{
|
||||
ID: uuid.New(),
|
||||
UserID: userID,
|
||||
Message: message,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
db.Create(log)
|
||||
}
|
||||
|
||||
```
|
||||
@@ -1,74 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
For DI files use implicit fields declaration styles (espesially
|
||||
for controllers, services, repositories, use cases, etc., not simple
|
||||
data structures).
|
||||
|
||||
So, instead of:
|
||||
|
||||
var orderController = &OrderController{
|
||||
orderService: orderService,
|
||||
botUserService: bot_users.GetBotUserService(),
|
||||
botService: bots.GetBotService(),
|
||||
userService: users.GetUserService(),
|
||||
}
|
||||
|
||||
Use:
|
||||
|
||||
var orderController = &OrderController{
|
||||
orderService,
|
||||
bot_users.GetBotUserService(),
|
||||
bots.GetBotService(),
|
||||
users.GetUserService(),
|
||||
}
|
||||
|
||||
This is needed to avoid forgetting to update DI style
|
||||
when we add new dependency.
|
||||
|
||||
---
|
||||
|
||||
Please force such usage if file look like this (see some
|
||||
services\controllers\repos definitions and getters):
|
||||
|
||||
var orderBackgroundService = &OrderBackgroundService{
|
||||
orderService: orderService,
|
||||
orderPaymentRepository: orderPaymentRepository,
|
||||
botService: bots.GetBotService(),
|
||||
paymentSettingsService: payment_settings.GetPaymentSettingsService(),
|
||||
|
||||
orderSubscriptionListeners: []OrderSubscriptionListener{},
|
||||
}
|
||||
|
||||
var orderController = &OrderController{
|
||||
orderService: orderService,
|
||||
botUserService: bot_users.GetBotUserService(),
|
||||
botService: bots.GetBotService(),
|
||||
userService: users.GetUserService(),
|
||||
}
|
||||
|
||||
func GetUniquePaymentRepository() *repositories.UniquePaymentRepository {
|
||||
return uniquePaymentRepository
|
||||
}
|
||||
|
||||
func GetOrderPaymentRepository() *repositories.OrderPaymentRepository {
|
||||
return orderPaymentRepository
|
||||
}
|
||||
|
||||
func GetOrderService() *OrderService {
|
||||
return orderService
|
||||
}
|
||||
|
||||
func GetOrderController() *OrderController {
|
||||
return orderController
|
||||
}
|
||||
|
||||
func GetOrderBackgroundService() *OrderBackgroundService {
|
||||
return orderBackgroundService
|
||||
}
|
||||
|
||||
func GetOrderRepository() *repositories.OrderRepository {
|
||||
return orderRepository
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
When writting migrations:
|
||||
|
||||
- write them for PostgreSQL
|
||||
- for PRIMARY UUID keys use gen_random_uuid()
|
||||
- for time use TIMESTAMPTZ (timestamp with zone)
|
||||
- split table, constraint and indexes declaration (table first, them other one by one)
|
||||
- format SQL in pretty way (add spaces, align columns types), constraints split by lines. The example:
|
||||
|
||||
CREATE TABLE marketplace_info (
|
||||
bot_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
title TEXT NOT NULL,
|
||||
description TEXT NOT NULL,
|
||||
short_description TEXT NOT NULL,
|
||||
tutorial_url TEXT,
|
||||
info_order BIGINT NOT NULL DEFAULT 0,
|
||||
is_published BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
|
||||
ALTER TABLE marketplace_info_images
|
||||
ADD CONSTRAINT fk_marketplace_info_images_bot_id
|
||||
FOREIGN KEY (bot_id)
|
||||
REFERENCES marketplace_info (bot_id);
|
||||
@@ -1,12 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
When applying changes, do not forget to refactor old code.
|
||||
|
||||
You can shortify, make more readable, improve code quality, etc.
|
||||
Common logic can be extracted to functions, constants, files, etc.
|
||||
|
||||
After each large change with more than ~50-100 lines of code - always run `make lint` (from backend root folder) and, if you change frontend, run `npm run format` (from frontend root folder).
|
||||
@@ -1,147 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
After writing tests, always launch them and verify that they pass.
|
||||
|
||||
## Test Naming Format
|
||||
|
||||
Use these naming patterns:
|
||||
|
||||
- `Test_WhatWeDo_WhatWeExpect`
|
||||
- `Test_WhatWeDo_WhichConditions_WhatWeExpect`
|
||||
|
||||
## Examples from Real Codebase:
|
||||
|
||||
- `Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated`
|
||||
- `Test_UpdateProject_WhenUserIsProjectAdmin_ProjectUpdated`
|
||||
- `Test_DeleteApiKey_WhenUserIsProjectMember_ReturnsForbidden`
|
||||
- `Test_GetProjectAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly`
|
||||
- `Test_ProjectLifecycleE2E_CompletesSuccessfully`
|
||||
|
||||
## Testing Philosophy
|
||||
|
||||
**Prefer Controllers Over Unit Tests:**
|
||||
|
||||
- Test through HTTP endpoints via controllers whenever possible
|
||||
- Avoid testing repositories, services in isolation - test via API instead
|
||||
- Only use unit tests for complex model logic when no API exists
|
||||
- Name test files `controller_test.go` or `service_test.go`, not `integration_test.go`
|
||||
|
||||
**Extract Common Logic to Testing Utilities:**
|
||||
|
||||
- Create `testing.go` or `testing/testing.go` files for shared test utilities
|
||||
- Extract router creation, user setup, models creation helpers (in API, not just structs creation)
|
||||
- Reuse common patterns across different test files
|
||||
|
||||
**Refactor Existing Tests:**
|
||||
|
||||
- When working with existing tests, always look for opportunities to refactor and improve
|
||||
- Extract repetitive setup code to common utilities
|
||||
- Simplify complex tests by breaking them into smaller, focused tests
|
||||
- Replace inline test data creation with reusable helper functions
|
||||
- Consolidate similar test patterns across different test files
|
||||
- Make tests more readable and maintainable for other developers
|
||||
|
||||
## Testing Utilities Structure
|
||||
|
||||
**Create `testing.go` or `testing/testing.go` files with common utilities:**
|
||||
|
||||
```go
|
||||
package projects_testing
|
||||
|
||||
// CreateTestRouter creates unified router for all controllers
|
||||
func CreateTestRouter(controllers ...ControllerInterface) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
v1 := router.Group("/api/v1")
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
|
||||
for _, controller := range controllers {
|
||||
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
|
||||
controller.RegisterRoutes(routerGroup)
|
||||
}
|
||||
}
|
||||
return router
|
||||
}
|
||||
|
||||
// CreateTestProjectViaAPI creates project through HTTP API
|
||||
func CreateTestProjectViaAPI(name string, owner *users_dto.SignInResponseDTO, router *gin.Engine) (*projects_models.Project, string) {
|
||||
request := projects_dto.CreateProjectRequestDTO{Name: name}
|
||||
w := MakeAPIRequest(router, "POST", "/api/v1/projects", "Bearer "+owner.Token, request)
|
||||
// Handle response...
|
||||
return project, owner.Token
|
||||
}
|
||||
|
||||
// AddMemberToProject adds member via API call
|
||||
func AddMemberToProject(project *projects_models.Project, member *users_dto.SignInResponseDTO, role users_enums.ProjectRole, ownerToken string, router *gin.Engine) {
|
||||
// Implementation...
|
||||
}
|
||||
```
|
||||
|
||||
## Controller Test Examples
|
||||
|
||||
**Permission-based testing:**
|
||||
|
||||
```go
|
||||
func Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated(t *testing.T) {
|
||||
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
project, _ := projects_testing.CreateTestProjectViaAPI("Test Project", owner, router)
|
||||
|
||||
request := CreateApiKeyRequestDTO{Name: "Test API Key"}
|
||||
var response ApiKey
|
||||
test_utils.MakePostRequestAndUnmarshal(t, router, "/api/v1/projects/api-keys/"+project.ID.String(), "Bearer "+owner.Token, request, http.StatusOK, &response)
|
||||
|
||||
assert.Equal(t, "Test API Key", response.Name)
|
||||
assert.NotEmpty(t, response.Token)
|
||||
}
|
||||
```
|
||||
|
||||
**Cross-project security testing:**
|
||||
|
||||
```go
|
||||
func Test_UpdateApiKey_WithApiKeyFromDifferentProject_ReturnsBadRequest(t *testing.T) {
|
||||
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
project1, _ := projects_testing.CreateTestProjectViaAPI("Project 1", owner1, router)
|
||||
project2, _ := projects_testing.CreateTestProjectViaAPI("Project 2", owner2, router)
|
||||
|
||||
apiKey := CreateTestApiKey("Cross Project Key", project1.ID, owner1.Token, router)
|
||||
|
||||
// Try to update via different project endpoint
|
||||
request := UpdateApiKeyRequestDTO{Name: &"Hacked Key"}
|
||||
resp := test_utils.MakePutRequest(t, router, "/api/v1/projects/api-keys/"+project2.ID.String()+"/"+apiKey.ID.String(), "Bearer "+owner2.Token, request, http.StatusBadRequest)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "API key does not belong to this project")
|
||||
}
|
||||
```
|
||||
|
||||
**E2E lifecycle testing:**
|
||||
|
||||
```go
|
||||
func Test_ProjectLifecycleE2E_CompletesSuccessfully(t *testing.T) {
|
||||
router := projects_testing.CreateTestRouter(GetProjectController(), GetMembershipController())
|
||||
|
||||
// 1. Create project
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
project := projects_testing.CreateTestProject("E2E Project", owner, router)
|
||||
|
||||
// 2. Add member
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
projects_testing.AddMemberToProject(project, member, users_enums.ProjectRoleMember, owner.Token, router)
|
||||
|
||||
// 3. Promote to admin
|
||||
projects_testing.ChangeMemberRole(project, member.UserID, users_enums.ProjectRoleAdmin, owner.Token, router)
|
||||
|
||||
// 4. Transfer ownership
|
||||
projects_testing.TransferProjectOwnership(project, member.UserID, owner.Token, router)
|
||||
|
||||
// 5. Verify new owner can manage project
|
||||
finalProject := projects_testing.GetProject(project.ID, member.Token, router)
|
||||
assert.Equal(t, project.ID, finalProject.ID)
|
||||
}
|
||||
```
|
||||
@@ -1,6 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
Always use time.Now().UTC() instead of time.Now()
|
||||
@@ -2,8 +2,18 @@
|
||||
DEV_DB_NAME=databasus
|
||||
DEV_DB_USERNAME=postgres
|
||||
DEV_DB_PASSWORD=Q1234567
|
||||
#app
|
||||
# app
|
||||
ENV_MODE=development
|
||||
# logging
|
||||
SHOW_DB_INSTALLATION_VERIFICATION_LOGS=true
|
||||
VICTORIA_LOGS_URL=http://localhost:9428
|
||||
VICTORIA_LOGS_PASSWORD=devpassword
|
||||
# tests
|
||||
TEST_LOCALHOST=localhost
|
||||
IS_SKIP_EXTERNAL_RESOURCES_TESTS=false
|
||||
# cloudflare turnstile
|
||||
CLOUDFLARE_TURNSTILE_SITE_KEY=
|
||||
CLOUDFLARE_TURNSTILE_SECRET_KEY=
|
||||
# db
|
||||
DATABASE_DSN=host=dev-db user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
|
||||
DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
|
||||
@@ -11,6 +21,12 @@ DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
|
||||
GOOSE_DRIVER=postgres
|
||||
GOOSE_DBSTRING=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
|
||||
GOOSE_MIGRATION_DIR=./migrations
|
||||
# valkey
|
||||
VALKEY_HOST=127.0.0.1
|
||||
VALKEY_PORT=6379
|
||||
VALKEY_USERNAME=
|
||||
VALKEY_PASSWORD=
|
||||
VALKEY_IS_SSL=false
|
||||
# testing
|
||||
# to get Google Drive env variables: add storage in UI and copy data from added storage here
|
||||
TEST_GOOGLE_DRIVE_CLIENT_ID=
|
||||
|
||||
@@ -10,4 +10,10 @@ DATABASE_URL=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disab
|
||||
# migrations
|
||||
GOOSE_DRIVER=postgres
|
||||
GOOSE_DBSTRING=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
|
||||
GOOSE_MIGRATION_DIR=./migrations
|
||||
GOOSE_MIGRATION_DIR=./migrations
|
||||
# valkey
|
||||
VALKEY_HOST=127.0.0.1
|
||||
VALKEY_PORT=6379
|
||||
VALKEY_USERNAME=
|
||||
VALKEY_PASSWORD=
|
||||
VALKEY_IS_SSL=false
|
||||
5
backend/.gitignore
vendored
5
backend/.gitignore
vendored
@@ -16,4 +16,7 @@ databasus-backend.exe
|
||||
ui/build/*
|
||||
pgdata-for-restore/
|
||||
temp/
|
||||
cmd.exe
|
||||
cmd.exe
|
||||
temp/
|
||||
valkey-data/
|
||||
victoria-logs-data/
|
||||
@@ -7,6 +7,16 @@ run:
|
||||
|
||||
linters:
|
||||
default: standard
|
||||
enable:
|
||||
- funcorder
|
||||
- bodyclose
|
||||
- errorlint
|
||||
- gocritic
|
||||
- unconvert
|
||||
- misspell
|
||||
- errname
|
||||
- noctx
|
||||
- modernize
|
||||
|
||||
settings:
|
||||
errcheck:
|
||||
@@ -14,6 +24,18 @@ linters:
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
- gofmt
|
||||
- gofumpt
|
||||
- golines
|
||||
- goimports
|
||||
- gci
|
||||
|
||||
settings:
|
||||
golines:
|
||||
max-len: 120
|
||||
gofumpt:
|
||||
module-path: databasus-backend
|
||||
extra-rules: true
|
||||
gci:
|
||||
sections:
|
||||
- standard
|
||||
- default
|
||||
- localmodule
|
||||
|
||||
@@ -2,10 +2,10 @@ run:
|
||||
go run cmd/main.go
|
||||
|
||||
test:
|
||||
go test -p=1 -count=1 -failfast -timeout 10m .\internal\...
|
||||
go test -p=1 -count=1 -failfast -timeout 15m ./internal/...
|
||||
|
||||
lint:
|
||||
golangci-lint fmt && golangci-lint run
|
||||
golangci-lint fmt ./cmd/... ./internal/... && golangci-lint run ./cmd/... ./internal/...
|
||||
|
||||
migration-create:
|
||||
goose create $(name) sql
|
||||
|
||||
@@ -12,9 +12,18 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-contrib/gzip"
|
||||
"github.com/gin-gonic/gin"
|
||||
swaggerFiles "github.com/swaggo/files"
|
||||
ginSwagger "github.com/swaggo/gin-swagger"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/disk"
|
||||
@@ -23,22 +32,21 @@ import (
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/restores"
|
||||
"databasus-backend/internal/features/restores/restoring"
|
||||
"databasus-backend/internal/features/storages"
|
||||
system_agent "databasus-backend/internal/features/system/agent"
|
||||
system_healthcheck "databasus-backend/internal/features/system/healthcheck"
|
||||
system_version "databasus-backend/internal/features/system/version"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
users_controllers "databasus-backend/internal/features/users/controllers"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
env_utils "databasus-backend/internal/util/env"
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
"databasus-backend/internal/util/logger"
|
||||
_ "databasus-backend/swagger" // swagger docs
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-contrib/gzip"
|
||||
"github.com/gin-gonic/gin"
|
||||
swaggerFiles "github.com/swaggo/files"
|
||||
ginSwagger "github.com/swaggo/gin-swagger"
|
||||
)
|
||||
|
||||
// @title Databasus Backend API
|
||||
@@ -52,14 +60,29 @@ import (
|
||||
func main() {
|
||||
log := logger.GetLogger()
|
||||
|
||||
runMigrations(log)
|
||||
cache_utils.TestCacheConnection()
|
||||
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
log.Info("Clearing cache...")
|
||||
|
||||
err := cache_utils.ClearAllCache()
|
||||
if err != nil {
|
||||
log.Error("Failed to clear cache", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
runMigrations(log)
|
||||
} else {
|
||||
log.Info("Skipping migrations (IS_PRIMARY_NODE is false)")
|
||||
}
|
||||
|
||||
// create directories that used for backups and restore
|
||||
err := files_utils.EnsureDirectories([]string{
|
||||
config.GetEnv().TempFolder,
|
||||
config.GetEnv().DataFolder,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error("Failed to ensure directories", "error", err)
|
||||
os.Exit(1)
|
||||
@@ -96,7 +119,9 @@ func main() {
|
||||
enableCors(ginApp)
|
||||
setUpRoutes(ginApp)
|
||||
setUpDependencies()
|
||||
|
||||
runBackgroundTasks(log)
|
||||
|
||||
mountFrontend(ginApp)
|
||||
|
||||
startServerWithGracefulShutdown(log, ginApp)
|
||||
@@ -124,7 +149,7 @@ func handlePasswordReset(log *slog.Logger) {
|
||||
resetPassword(*email, *newPassword, log)
|
||||
}
|
||||
|
||||
func resetPassword(email string, newPassword string, log *slog.Logger) {
|
||||
func resetPassword(email, newPassword string, log *slog.Logger) {
|
||||
log.Info("Resetting password...")
|
||||
|
||||
userService := users_services.GetUserService()
|
||||
@@ -162,6 +187,9 @@ func startServerWithGracefulShutdown(log *slog.Logger, app *gin.Engine) {
|
||||
<-quit
|
||||
log.Info("Shutdown signal received")
|
||||
|
||||
// Gracefully shutdown VictoriaLogs writer
|
||||
logger.ShutdownVictoriaLogs(5 * time.Second)
|
||||
|
||||
// The context is used to inform the server it has 10 seconds to finish
|
||||
// the request it is currently handling
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -183,6 +211,11 @@ func setUpRoutes(r *gin.Engine) {
|
||||
userController := users_controllers.GetUserController()
|
||||
userController.RegisterRoutes(v1)
|
||||
system_healthcheck.GetHealthcheckController().RegisterRoutes(v1)
|
||||
system_version.GetVersionController().RegisterRoutes(v1)
|
||||
system_agent.GetAgentController().RegisterRoutes(v1)
|
||||
backups_controllers.GetBackupController().RegisterPublicRoutes(v1)
|
||||
backups_controllers.GetPostgresWalBackupController().RegisterRoutes(v1)
|
||||
databases.GetDatabaseController().RegisterPublicRoutes(v1)
|
||||
|
||||
// Setup auth middleware
|
||||
userService := users_services.GetUserService()
|
||||
@@ -199,7 +232,7 @@ func setUpRoutes(r *gin.Engine) {
|
||||
notifiers.GetNotifierController().RegisterRoutes(protected)
|
||||
storages.GetStorageController().RegisterRoutes(protected)
|
||||
databases.GetDatabaseController().RegisterRoutes(protected)
|
||||
backups.GetBackupController().RegisterRoutes(protected)
|
||||
backups_controllers.GetBackupController().RegisterRoutes(protected)
|
||||
restores.GetRestoreController().RegisterRoutes(protected)
|
||||
healthcheck_config.GetHealthcheckConfigController().RegisterRoutes(protected)
|
||||
healthcheck_attempt.GetHealthcheckAttemptController().RegisterRoutes(protected)
|
||||
@@ -211,37 +244,87 @@ func setUpRoutes(r *gin.Engine) {
|
||||
|
||||
func setUpDependencies() {
|
||||
databases.SetupDependencies()
|
||||
backups.SetupDependencies()
|
||||
backups_services.SetupDependencies()
|
||||
restores.SetupDependencies()
|
||||
healthcheck_config.SetupDependencies()
|
||||
audit_logs.SetupDependencies()
|
||||
notifiers.SetupDependencies()
|
||||
storages.SetupDependencies()
|
||||
backups_config.SetupDependencies()
|
||||
task_cancellation.SetupDependencies()
|
||||
}
|
||||
|
||||
func runBackgroundTasks(log *slog.Logger) {
|
||||
log.Info("Preparing to run background tasks...")
|
||||
|
||||
// Create context that will be cancelled on shutdown
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Set up signal handling for graceful shutdown
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-quit
|
||||
log.Info("Shutdown signal received, cancelling all background tasks")
|
||||
cancel()
|
||||
}()
|
||||
|
||||
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
|
||||
if err != nil {
|
||||
log.Error("Failed to clean temp folder", "error", err)
|
||||
}
|
||||
|
||||
go runWithPanicLogging(log, "backup background service", func() {
|
||||
backups.GetBackupBackgroundService().Run()
|
||||
})
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
log.Info("Starting primary node background tasks...")
|
||||
|
||||
go runWithPanicLogging(log, "restore background service", func() {
|
||||
restores.GetRestoreBackgroundService().Run()
|
||||
})
|
||||
go runWithPanicLogging(log, "backup background service", func() {
|
||||
backuping.GetBackupsScheduler().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
|
||||
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run()
|
||||
})
|
||||
go runWithPanicLogging(log, "backup cleaner background service", func() {
|
||||
backuping.GetBackupCleaner().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "audit log cleanup background service", func() {
|
||||
audit_logs.GetAuditLogBackgroundService().Run()
|
||||
})
|
||||
go runWithPanicLogging(log, "restore background service", func() {
|
||||
restoring.GetRestoresScheduler().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
|
||||
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "audit log cleanup background service", func() {
|
||||
audit_logs.GetAuditLogBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "download token cleanup background service", func() {
|
||||
backups_download.GetDownloadTokenBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "backup nodes registry background service", func() {
|
||||
backuping.GetBackupNodesRegistry().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "restore nodes registry background service", func() {
|
||||
restoring.GetRestoreNodesRegistry().Run(ctx)
|
||||
})
|
||||
} else {
|
||||
log.Info("Skipping primary node tasks as not primary node")
|
||||
}
|
||||
|
||||
if config.GetEnv().IsProcessingNode {
|
||||
log.Info("Starting backup node background tasks...")
|
||||
|
||||
go runWithPanicLogging(log, "backup node", func() {
|
||||
backuping.GetBackuperNode().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "restore node", func() {
|
||||
restoring.GetRestorerNode().Run(ctx)
|
||||
})
|
||||
} else {
|
||||
log.Info("Skipping backup/restore node tasks as not backup node")
|
||||
}
|
||||
}
|
||||
|
||||
func runWithPanicLogging(log *slog.Logger, serviceName string, fn func()) {
|
||||
@@ -270,7 +353,9 @@ func generateSwaggerDocs(log *slog.Logger) {
|
||||
return
|
||||
}
|
||||
|
||||
cmd := exec.Command("swag", "init", "-d", currentDir, "-g", "cmd/main.go", "-o", "swagger")
|
||||
cmd := exec.CommandContext(
|
||||
context.Background(), "swag", "init", "-d", currentDir, "-g", "cmd/main.go", "-o", "swagger",
|
||||
)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
@@ -284,16 +369,13 @@ func generateSwaggerDocs(log *slog.Logger) {
|
||||
func runMigrations(log *slog.Logger) {
|
||||
log.Info("Running database migrations...")
|
||||
|
||||
cmd := exec.Command("goose", "up")
|
||||
cmd := exec.CommandContext(context.Background(), "goose", "-dir", "./migrations", "up")
|
||||
cmd.Env = append(
|
||||
os.Environ(),
|
||||
"GOOSE_DRIVER=postgres",
|
||||
"GOOSE_DBSTRING="+config.GetEnv().DatabaseDsn,
|
||||
)
|
||||
|
||||
// Set the working directory to where migrations are located
|
||||
cmd.Dir = "./migrations"
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
log.Error("Failed to run migrations", "error", err, "output", string(output))
|
||||
|
||||
@@ -19,6 +19,35 @@ services:
|
||||
command: -p 5437
|
||||
shm_size: 10gb
|
||||
|
||||
# Valkey for caching
|
||||
dev-valkey:
|
||||
image: valkey/valkey:9.0.1-alpine
|
||||
ports:
|
||||
- "${VALKEY_PORT:-6379}:6379"
|
||||
volumes:
|
||||
- ./valkey-data:/data
|
||||
container_name: dev-valkey
|
||||
healthcheck:
|
||||
test: ["CMD", "valkey-cli", "ping"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 20s
|
||||
|
||||
# VictoriaLogs for external logging
|
||||
victoria-logs:
|
||||
image: victoriametrics/victoria-logs:latest
|
||||
container_name: victoria-logs
|
||||
ports:
|
||||
- "9428:9428"
|
||||
command:
|
||||
- -storageDataPath=/victoria-logs-data
|
||||
- -retentionPeriod=7d
|
||||
- -httpAuth.password=devpassword
|
||||
volumes:
|
||||
- ./victoria-logs-data:/victoria-logs-data
|
||||
restart: unless-stopped
|
||||
|
||||
# Test MinIO container
|
||||
test-minio:
|
||||
image: minio/minio:latest
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
module databasus-backend
|
||||
|
||||
go 1.24.4
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
|
||||
@@ -25,9 +25,9 @@ require (
|
||||
github.com/swaggo/files v1.0.1
|
||||
github.com/swaggo/gin-swagger v1.6.0
|
||||
github.com/swaggo/swag v1.16.4
|
||||
github.com/valkey-io/valkey-go v1.0.70
|
||||
go.mongodb.org/mongo-driver v1.17.6
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/time v0.14.0
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
gorm.io/gorm v1.26.1
|
||||
)
|
||||
@@ -185,6 +185,7 @@ require (
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/term v0.38.0 // indirect
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
gopkg.in/validator.v2 v2.0.1 // indirect
|
||||
moul.io/http2curl/v2 v2.3.0 // indirect
|
||||
@@ -269,7 +270,7 @@ require (
|
||||
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.38.0 // indirect
|
||||
golang.org/x/arch v0.17.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/oauth2 v0.33.0
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
|
||||
@@ -539,8 +539,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||
github.com/onsi/ginkgo/v2 v2.17.3 h1:oJcvKpIb7/8uLpDDtnQuf18xVnwKp8DTD7DQ6gTd/MU=
|
||||
github.com/onsi/ginkgo/v2 v2.17.3/go.mod h1:nP2DPOQoNsQmsVyv5rDA8JkXQoCs6goXIvr/PRJ1eCc=
|
||||
github.com/onsi/gomega v1.37.0 h1:CdEG8g0S133B4OswTDC/5XPSzE1OeP29QOioj2PID2Y=
|
||||
github.com/onsi/gomega v1.37.0/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0=
|
||||
github.com/onsi/gomega v1.38.3 h1:eTX+W6dobAYfFeGC2PV6RwXRu/MyT+cQguijutvkpSM=
|
||||
github.com/onsi/gomega v1.38.3/go.mod h1:ZCU1pkQcXDO5Sl9/VVEGlDyp+zm0m1cmeG5TOzLgdh4=
|
||||
github.com/oracle/oci-go-sdk/v65 v65.104.0 h1:l9awEvzWvxmYhy/97A0hZ87pa7BncYXmcO/S8+rvgK0=
|
||||
github.com/oracle/oci-go-sdk/v65 v65.104.0/go.mod h1:oB8jFGVc/7/zJ+DbleE8MzGHjhs2ioCz5stRTdZdIcY=
|
||||
github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg=
|
||||
@@ -660,6 +660,8 @@ github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
|
||||
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
|
||||
github.com/unknwon/goconfig v1.0.0 h1:rS7O+CmUdli1T+oDm7fYj1MwqNWtEJfNj+FqcUHML8U=
|
||||
github.com/unknwon/goconfig v1.0.0/go.mod h1:qu2ZQ/wcC/if2u32263HTVC39PeOQRSmidQk3DuDFQ8=
|
||||
github.com/valkey-io/valkey-go v1.0.70 h1:mjYNT8qiazxDAJ0QNQ8twWT/YFOkOoRd40ERV2mB49Y=
|
||||
github.com/valkey-io/valkey-go v1.0.70/go.mod h1:VGhZ6fs68Qrn2+OhH+6waZH27bjpgQOiLyUQyXuYK5k=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM=
|
||||
@@ -720,6 +722,8 @@ go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
|
||||
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
|
||||
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
||||
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/arch v0.17.0 h1:4O3dfLzd+lQewptAHqjewQZQDyEdejz3VwgeYwkZneU=
|
||||
golang.org/x/arch v0.17.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
@@ -818,8 +822,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
env_utils "databasus-backend/internal/util/env"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"databasus-backend/internal/util/tools"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -11,6 +8,10 @@ import (
|
||||
|
||||
"github.com/ilyakaznacheev/cleanenv"
|
||||
"github.com/joho/godotenv"
|
||||
|
||||
env_utils "databasus-backend/internal/util/env"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
var log = logger.GetLogger()
|
||||
@@ -22,13 +23,32 @@ const (
|
||||
|
||||
type EnvVariables struct {
|
||||
IsTesting bool
|
||||
DatabaseDsn string `env:"DATABASE_DSN" required:"true"`
|
||||
EnvMode env_utils.EnvMode `env:"ENV_MODE" required:"true"`
|
||||
PostgresesInstallDir string `env:"POSTGRES_INSTALL_DIR"`
|
||||
MysqlInstallDir string `env:"MYSQL_INSTALL_DIR"`
|
||||
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
|
||||
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`
|
||||
|
||||
// Internal database
|
||||
DatabaseDsn string `env:"DATABASE_DSN" required:"true"`
|
||||
// Internal Valkey
|
||||
ValkeyHost string `env:"VALKEY_HOST" required:"true"`
|
||||
ValkeyPort string `env:"VALKEY_PORT" required:"true"`
|
||||
ValkeyUsername string `env:"VALKEY_USERNAME" required:"true"`
|
||||
ValkeyPassword string `env:"VALKEY_PASSWORD" required:"true"`
|
||||
ValkeyIsSsl bool `env:"VALKEY_IS_SSL" required:"true"`
|
||||
|
||||
IsCloud bool `env:"IS_CLOUD"`
|
||||
TestLocalhost string `env:"TEST_LOCALHOST"`
|
||||
|
||||
ShowDbInstallationVerificationLogs bool `env:"SHOW_DB_INSTALLATION_VERIFICATION_LOGS"`
|
||||
IsSkipExternalResourcesTests bool `env:"IS_SKIP_EXTERNAL_RESOURCES_TESTS"`
|
||||
|
||||
IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
|
||||
IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
|
||||
IsProcessingNode bool `env:"IS_PROCESSING_NODE"`
|
||||
NodeNetworkThroughputMBs int `env:"NODE_NETWORK_THROUGHPUT_MBPS"`
|
||||
|
||||
DataFolder string
|
||||
TempFolder string
|
||||
SecretKeyPath string
|
||||
@@ -85,6 +105,10 @@ type EnvVariables struct {
|
||||
GoogleClientID string `env:"GOOGLE_CLIENT_ID"`
|
||||
GoogleClientSecret string `env:"GOOGLE_CLIENT_SECRET"`
|
||||
|
||||
// Cloudflare Turnstile
|
||||
CloudflareTurnstileSecretKey string `env:"CLOUDFLARE_TURNSTILE_SECRET_KEY"`
|
||||
CloudflareTurnstileSiteKey string `env:"CLOUDFLARE_TURNSTILE_SITE_KEY"`
|
||||
|
||||
// testing Telegram
|
||||
TestTelegramBotToken string `env:"TEST_TELEGRAM_BOT_TOKEN"`
|
||||
TestTelegramChatID string `env:"TEST_TELEGRAM_CHAT_ID"`
|
||||
@@ -95,6 +119,16 @@ type EnvVariables struct {
|
||||
TestSupabaseUsername string `env:"TEST_SUPABASE_USERNAME"`
|
||||
TestSupabasePassword string `env:"TEST_SUPABASE_PASSWORD"`
|
||||
TestSupabaseDatabase string `env:"TEST_SUPABASE_DATABASE"`
|
||||
|
||||
// SMTP configuration (optional)
|
||||
SMTPHost string `env:"SMTP_HOST"`
|
||||
SMTPPort int `env:"SMTP_PORT"`
|
||||
SMTPUser string `env:"SMTP_USER"`
|
||||
SMTPPassword string `env:"SMTP_PASSWORD"`
|
||||
SMTPFrom string `env:"SMTP_FROM"`
|
||||
|
||||
// Application URL (optional) - used for email links
|
||||
DatabasusURL string `env:"DATABASUS_URL"`
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -155,6 +189,21 @@ func loadEnvVariables() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Set default value for ShowDbInstallationVerificationLogs if not defined
|
||||
if os.Getenv("SHOW_DB_INSTALLATION_VERIFICATION_LOGS") == "" {
|
||||
env.ShowDbInstallationVerificationLogs = true
|
||||
}
|
||||
|
||||
// Set default value for IsSkipExternalTests if not defined
|
||||
if os.Getenv("IS_SKIP_EXTERNAL_RESOURCES_TESTS") == "" {
|
||||
env.IsSkipExternalResourcesTests = false
|
||||
}
|
||||
|
||||
// Set default value for IsCloud if not defined
|
||||
if os.Getenv("IS_CLOUD") == "" {
|
||||
env.IsCloud = false
|
||||
}
|
||||
|
||||
for _, arg := range os.Args {
|
||||
if strings.Contains(arg, "test") {
|
||||
env.IsTesting = true
|
||||
@@ -162,6 +211,14 @@ func loadEnvVariables() {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for external database override
|
||||
if externalDsn := os.Getenv("DANGEROUS_EXTERNAL_DATABASE_DSN"); externalDsn != "" {
|
||||
log.Warn(
|
||||
"Using DANGEROUS_EXTERNAL_DATABASE_DSN - connecting to external database instead of internal PostgreSQL",
|
||||
)
|
||||
env.DatabaseDsn = externalDsn
|
||||
}
|
||||
|
||||
if env.DatabaseDsn == "" {
|
||||
log.Error("DATABASE_DSN is empty")
|
||||
os.Exit(1)
|
||||
@@ -178,16 +235,80 @@ func loadEnvVariables() {
|
||||
log.Info("ENV_MODE loaded", "mode", env.EnvMode)
|
||||
|
||||
env.PostgresesInstallDir = filepath.Join(backendRoot, "tools", "postgresql")
|
||||
tools.VerifyPostgresesInstallation(log, env.EnvMode, env.PostgresesInstallDir)
|
||||
tools.VerifyPostgresesInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.PostgresesInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
env.MysqlInstallDir = filepath.Join(backendRoot, "tools", "mysql")
|
||||
tools.VerifyMysqlInstallation(log, env.EnvMode, env.MysqlInstallDir)
|
||||
tools.VerifyMysqlInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.MysqlInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
env.MariadbInstallDir = filepath.Join(backendRoot, "tools", "mariadb")
|
||||
tools.VerifyMariadbInstallation(log, env.EnvMode, env.MariadbInstallDir)
|
||||
tools.VerifyMariadbInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.MariadbInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
env.MongodbInstallDir = filepath.Join(backendRoot, "tools", "mongodb")
|
||||
tools.VerifyMongodbInstallation(log, env.EnvMode, env.MongodbInstallDir)
|
||||
tools.VerifyMongodbInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.MongodbInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
if env.NodeNetworkThroughputMBs == 0 {
|
||||
env.NodeNetworkThroughputMBs = 125 // 1 Gbit/s
|
||||
}
|
||||
|
||||
if !env.IsManyNodesMode {
|
||||
env.IsPrimaryNode = true
|
||||
env.IsProcessingNode = true
|
||||
}
|
||||
|
||||
if env.TestLocalhost == "" {
|
||||
env.TestLocalhost = "localhost"
|
||||
}
|
||||
|
||||
// Valkey
|
||||
if env.ValkeyHost == "" {
|
||||
log.Error("VALKEY_HOST is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
if env.ValkeyPort == "" {
|
||||
log.Error("VALKEY_PORT is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Check for external Valkey override
|
||||
if externalValkeyHost := os.Getenv("DANGEROUS_VALKEY_HOST"); externalValkeyHost != "" {
|
||||
log.Warn(
|
||||
"Using DANGEROUS_VALKEY_* variables - connecting to external Valkey instead of internal instance",
|
||||
)
|
||||
env.ValkeyHost = externalValkeyHost
|
||||
|
||||
if externalValkeyPort := os.Getenv("DANGEROUS_VALKEY_PORT"); externalValkeyPort != "" {
|
||||
env.ValkeyPort = externalValkeyPort
|
||||
}
|
||||
if externalValkeyUsername := os.Getenv("DANGEROUS_VALKEY_USERNAME"); externalValkeyUsername != "" {
|
||||
env.ValkeyUsername = externalValkeyUsername
|
||||
}
|
||||
if externalValkeyPassword := os.Getenv("DANGEROUS_VALKEY_PASSWORD"); externalValkeyPassword != "" {
|
||||
env.ValkeyPassword = externalValkeyPassword
|
||||
}
|
||||
if externalValkeyIsSsl := os.Getenv("DANGEROUS_VALKEY_IS_SSL"); externalValkeyIsSsl != "" {
|
||||
env.ValkeyIsSsl = externalValkeyIsSsl == "true"
|
||||
}
|
||||
}
|
||||
|
||||
// Store the data and temp folders one level below the root
|
||||
// (projectRoot/databasus-data -> /databasus-data)
|
||||
|
||||
@@ -1,33 +1,51 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AuditLogBackgroundService struct {
|
||||
auditLogService *AuditLogService
|
||||
logger *slog.Logger
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *AuditLogBackgroundService) Run() {
|
||||
s.logger.Info("Starting audit log cleanup background service")
|
||||
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
for {
|
||||
if config.IsShouldShutdown() {
|
||||
s.logger.Info("Starting audit log cleanup background service")
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
time.Sleep(1 * time.Hour)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/gorm"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
func Test_CleanOldAuditLogs_DeletesLogsOlderThanOneYear(t *testing.T) {
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
@@ -50,7 +51,7 @@ func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
|
||||
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "only administrators can view global audit logs" {
|
||||
if errors.Is(err, ErrOnlyAdminsCanViewGlobalLogs) {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -99,7 +100,7 @@ func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
if errors.Is(err, ErrInsufficientPermissionsToViewLogs) {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -6,15 +6,15 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_GetGlobalAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly(t *testing.T) {
|
||||
|
||||
@@ -1,21 +1,30 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var auditLogRepository = &AuditLogRepository{}
|
||||
var auditLogService = &AuditLogService{
|
||||
auditLogRepository,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
var (
|
||||
auditLogRepository = &AuditLogRepository{}
|
||||
auditLogService = &AuditLogService{
|
||||
auditLogRepository,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
)
|
||||
|
||||
var auditLogController = &AuditLogController{
|
||||
auditLogService,
|
||||
}
|
||||
|
||||
var auditLogBackgroundService = &AuditLogBackgroundService{
|
||||
auditLogService,
|
||||
logger.GetLogger(),
|
||||
auditLogService: auditLogService,
|
||||
logger: logger.GetLogger(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
func GetAuditLogService() *AuditLogService {
|
||||
@@ -30,8 +39,23 @@ func GetAuditLogBackgroundService() *AuditLogBackgroundService {
|
||||
return auditLogBackgroundService
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
12
backend/internal/features/audit_logs/errors.go
Normal file
12
backend/internal/features/audit_logs/errors.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package audit_logs
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrOnlyAdminsCanViewGlobalLogs = errors.New(
|
||||
"only administrators can view global audit logs",
|
||||
)
|
||||
ErrInsufficientPermissionsToViewLogs = errors.New(
|
||||
"insufficient permissions to view user audit logs",
|
||||
)
|
||||
)
|
||||
@@ -1,10 +1,11 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
type AuditLogRepository struct{}
|
||||
@@ -21,7 +22,7 @@ func (r *AuditLogRepository) GetGlobal(
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLogDTO, error) {
|
||||
var auditLogs = make([]*AuditLogDTO, 0)
|
||||
auditLogs := make([]*AuditLogDTO, 0)
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
@@ -37,7 +38,7 @@ func (r *AuditLogRepository) GetGlobal(
|
||||
LEFT JOIN users u ON al.user_id = u.id
|
||||
LEFT JOIN workspaces w ON al.workspace_id = w.id`
|
||||
|
||||
args := []interface{}{}
|
||||
args := []any{}
|
||||
|
||||
if beforeDate != nil {
|
||||
sql += " WHERE al.created_at < ?"
|
||||
@@ -57,7 +58,7 @@ func (r *AuditLogRepository) GetByUser(
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLogDTO, error) {
|
||||
var auditLogs = make([]*AuditLogDTO, 0)
|
||||
auditLogs := make([]*AuditLogDTO, 0)
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
@@ -74,7 +75,7 @@ func (r *AuditLogRepository) GetByUser(
|
||||
LEFT JOIN workspaces w ON al.workspace_id = w.id
|
||||
WHERE al.user_id = ?`
|
||||
|
||||
args := []interface{}{userID}
|
||||
args := []any{userID}
|
||||
|
||||
if beforeDate != nil {
|
||||
sql += " AND al.created_at < ?"
|
||||
@@ -94,7 +95,7 @@ func (r *AuditLogRepository) GetByWorkspace(
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLogDTO, error) {
|
||||
var auditLogs = make([]*AuditLogDTO, 0)
|
||||
auditLogs := make([]*AuditLogDTO, 0)
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
@@ -111,7 +112,7 @@ func (r *AuditLogRepository) GetByWorkspace(
|
||||
LEFT JOIN workspaces w ON al.workspace_id = w.id
|
||||
WHERE al.workspace_id = ?`
|
||||
|
||||
args := []interface{}{workspaceID}
|
||||
args := []any{workspaceID}
|
||||
|
||||
if beforeDate != nil {
|
||||
sql += " AND al.created_at < ?"
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogService struct {
|
||||
@@ -44,7 +43,7 @@ func (s *AuditLogService) GetGlobalAuditLogs(
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
if user.Role != user_enums.UserRoleAdmin {
|
||||
return nil, errors.New("only administrators can view global audit logs")
|
||||
return nil, ErrOnlyAdminsCanViewGlobalLogs
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
@@ -79,7 +78,7 @@ func (s *AuditLogService) GetUserAuditLogs(
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
// Users can view their own logs, ADMIN can view any user's logs
|
||||
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
|
||||
return nil, errors.New("insufficient permissions to view user audit logs")
|
||||
return nil, ErrInsufficientPermissionsToViewLogs
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
)
|
||||
|
||||
func Test_AuditLogs_WorkspaceSpecificLogs(t *testing.T) {
|
||||
|
||||
@@ -1,254 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/storages"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/period"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
type BackupBackgroundService struct {
|
||||
backupService *BackupService
|
||||
backupRepository *BackupRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
|
||||
lastBackupTime time.Time
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) Run() {
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.cleanOldBackups(); err != nil {
|
||||
s.logger.Error("Failed to clean old backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
time.Sleep(1 * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) IsBackupsWorkerRunning() bool {
|
||||
// if last backup time is more than 5 minutes ago, return false
|
||||
return s.lastBackupTime.After(time.Now().UTC().Add(-5 * time.Minute))
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) failBackupsInProgress() error {
|
||||
backupsInProgress, err := s.backupRepository.FindByStatus(BackupStatusInProgress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backup := range backupsInProgress {
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
failMessage := "Backup failed due to application restart"
|
||||
backup.FailMessage = &failMessage
|
||||
backup.Status = BackupStatusFailed
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
s.backupService.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&failMessage,
|
||||
)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) cleanOldBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
backupStorePeriod := backupConfig.StorePeriod
|
||||
|
||||
if backupStorePeriod == period.PeriodForever {
|
||||
continue
|
||||
}
|
||||
|
||||
storeDuration := backupStorePeriod.ToDuration()
|
||||
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
|
||||
|
||||
oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
|
||||
backupConfig.DatabaseID,
|
||||
dateBeforeBackupsShouldBeDeleted,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find old backups for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, backup := range oldBackups {
|
||||
storage, err := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get storage by ID",
|
||||
"storageId",
|
||||
backup.StorageID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
err = storage.DeleteFile(encryptor, backup.ID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
|
||||
}
|
||||
|
||||
if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
|
||||
s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Deleted old backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) runPendingBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
if backupConfig.BackupInterval == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get last backup for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
var lastBackupTime *time.Time
|
||||
if lastBackup != nil {
|
||||
lastBackupTime = &lastBackup.CreatedAt
|
||||
}
|
||||
|
||||
remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
|
||||
|
||||
if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
|
||||
remainedBackupTryCount > 0 {
|
||||
s.logger.Info(
|
||||
"Triggering scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"intervalType",
|
||||
backupConfig.BackupInterval.Interval,
|
||||
)
|
||||
|
||||
go s.backupService.MakeBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
|
||||
s.logger.Info(
|
||||
"Successfully triggered scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRemainedBackupTryCount returns the number of remaining backup tries for a given backup.
|
||||
// If the backup is not failed or the backup config does not allow retries, it returns 0.
|
||||
// If the backup is failed and the backup config allows retries, it returns the number of remaining tries.
|
||||
// If the backup is failed and the backup config does not allow retries, it returns 0.
|
||||
func (s *BackupBackgroundService) GetRemainedBackupTryCount(lastBackup *Backup) int {
|
||||
if lastBackup == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if lastBackup.Status != BackupStatusFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
if !backupConfig.IsRetryIfFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
maxFailedTriesCount := backupConfig.MaxFailedTriesCount
|
||||
|
||||
lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
|
||||
lastBackup.DatabaseID,
|
||||
maxFailedTriesCount,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to find last backups by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
lastFailedBackups := make([]*Backup, 0)
|
||||
|
||||
for _, backup := range lastBackups {
|
||||
if backup.Status == BackupStatusFailed {
|
||||
lastFailedBackups = append(lastFailedBackups, backup)
|
||||
}
|
||||
}
|
||||
|
||||
return maxFailedTriesCount - len(lastFailedBackups)
|
||||
}
|
||||
@@ -1,321 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/period"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add old backup
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 2)
|
||||
|
||||
// cleanup
|
||||
for _, backup := range backups {
|
||||
err := backupRepository.DeleteByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add recent backup (1 hour ago)
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1) // Should still be 1 backup, no new backup created
|
||||
|
||||
// cleanup
|
||||
for _, backup := range backups {
|
||||
err := backupRepository.DeleteByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database with retries disabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = false
|
||||
backupConfig.MaxFailedTriesCount = 0
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add failed backup
|
||||
failMessage := "backup failed"
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1) // Should still be 1 backup, no retry attempted
|
||||
|
||||
// cleanup
|
||||
for _, backup := range backups {
|
||||
err := backupRepository.DeleteByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database with retries enabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = true
|
||||
backupConfig.MaxFailedTriesCount = 3
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add failed backup
|
||||
failMessage := "backup failed"
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 2) // Should have 2 backups, retry was attempted
|
||||
|
||||
// cleanup
|
||||
for _, backup := range backups {
|
||||
err := backupRepository.DeleteByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(100 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database with retries enabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = true
|
||||
backupConfig.MaxFailedTriesCount = 3
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
failMessage := "backup failed"
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
}
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 3) // Should have 3 backups, not more than max
|
||||
|
||||
// cleanup
|
||||
for _, backup := range backups {
|
||||
err := backupRepository.DeleteByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackupContextManager struct {
|
||||
mu sync.RWMutex
|
||||
cancelFuncs map[uuid.UUID]context.CancelFunc
|
||||
cancelledBackups map[uuid.UUID]bool
|
||||
}
|
||||
|
||||
func NewBackupContextManager() *BackupContextManager {
|
||||
return &BackupContextManager{
|
||||
cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
|
||||
cancelledBackups: make(map[uuid.UUID]bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.cancelFuncs[backupID] = cancelFunc
|
||||
delete(m.cancelledBackups, backupID)
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.cancelledBackups[backupID] {
|
||||
return nil
|
||||
}
|
||||
|
||||
cancelFunc, exists := m.cancelFuncs[backupID]
|
||||
if exists {
|
||||
cancelFunc()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
}
|
||||
|
||||
m.cancelledBackups[backupID] = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) IsCancelled(backupID uuid.UUID) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.cancelledBackups[backupID]
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) UnregisterBackup(backupID uuid.UUID) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
delete(m.cancelledBackups, backupID)
|
||||
}
|
||||
469
backend/internal/features/backups/backups/backuping/backuper.go
Normal file
469
backend/internal/features/backups/backups/backuping/backuper.go
Normal file
@@ -0,0 +1,469 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
)
|
||||
|
||||
const (
|
||||
heartbeatTickerInterval = 15 * time.Second
|
||||
backuperHeathcheckThreshold = 5 * time.Minute
|
||||
)
|
||||
|
||||
type BackuperNode struct {
|
||||
databaseService *databases.DatabaseService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
backupRepository *backups_core.BackupRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
notificationSender backups_core.NotificationSender
|
||||
backupCancelManager *tasks_cancellation.TaskCancelManager
|
||||
backupNodesRegistry *BackupNodesRegistry
|
||||
logger *slog.Logger
|
||||
createBackupUseCase backups_core.CreateBackupUsecase
|
||||
nodeID uuid.UUID
|
||||
|
||||
lastHeartbeat time.Time
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (n *BackuperNode) Run(ctx context.Context) {
|
||||
wasAlreadyRun := n.hasRun.Load()
|
||||
|
||||
n.runOnce.Do(func() {
|
||||
n.hasRun.Store(true)
|
||||
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
|
||||
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
|
||||
|
||||
backupNode := BackupNode{
|
||||
ID: n.nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
|
||||
n.logger.Error("Failed to register node in registry", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
|
||||
go func() {
|
||||
n.MakeBackup(backupID, isCallNotifier)
|
||||
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
|
||||
n.logger.Error(
|
||||
"Failed to publish backup completion",
|
||||
"error",
|
||||
err,
|
||||
"backupID",
|
||||
backupID,
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
|
||||
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(heartbeatTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
|
||||
|
||||
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
|
||||
n.logger.Error("Failed to unregister node from registry", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
n.sendHeartbeat(&backupNode)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", n))
|
||||
}
|
||||
}
|
||||
|
||||
func (n *BackuperNode) IsBackuperRunning() bool {
|
||||
return n.lastHeartbeat.After(time.Now().UTC().Add(-backuperHeathcheckThreshold))
|
||||
}
|
||||
|
||||
func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
backup, err := n.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get backup by ID", "backupId", backupID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
databaseID := backup.DatabaseID
|
||||
|
||||
database, err := n.databaseService.GetDatabaseByID(databaseID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get database by ID", "databaseId", databaseID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backupConfig, err := n.backupConfigService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if backupConfig.StorageID == nil {
|
||||
n.logger.Error("Backup config storage ID is not defined")
|
||||
return
|
||||
}
|
||||
|
||||
storage, err := n.storageService.GetStorageByID(*backupConfig.StorageID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get storage by ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now().UTC()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
n.backupCancelManager.RegisterTask(backup.ID, cancel)
|
||||
defer n.backupCancelManager.UnregisterTask(backup.ID)
|
||||
|
||||
backupProgressListener := func(
|
||||
completedMBs float64,
|
||||
) {
|
||||
backup.BackupSizeMb = completedMBs
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Check size limit (0 = unlimited)
|
||||
if backupConfig.MaxBackupSizeMB > 0 &&
|
||||
completedMBs > float64(backupConfig.MaxBackupSizeMB) {
|
||||
errMsg := fmt.Sprintf(
|
||||
"backup size (%.2f MB) exceeded maximum allowed size (%d MB)",
|
||||
completedMBs,
|
||||
backupConfig.MaxBackupSizeMB,
|
||||
)
|
||||
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.IsSkipRetry = true
|
||||
backup.FailMessage = &errMsg
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup with size exceeded error", "error", err)
|
||||
}
|
||||
cancel() // Cancel the backup context
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to update backup progress", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
backupMetadata, err := n.createBackupUseCase.Execute(
|
||||
ctx,
|
||||
backup,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
backupProgressListener,
|
||||
)
|
||||
if err != nil {
|
||||
// Check if backup was already marked as failed by progress listener (e.g., size limit exceeded)
|
||||
// If so, skip error handling to avoid overwriting the status
|
||||
currentBackup, fetchErr := n.backupRepository.FindByID(backup.ID)
|
||||
if fetchErr == nil && currentBackup.Status == backups_core.BackupStatusFailed {
|
||||
n.logger.Warn(
|
||||
"Backup already marked as failed by progress listener, skipping error handling",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"failMessage",
|
||||
*currentBackup.FailMessage,
|
||||
)
|
||||
|
||||
// Still call notification for size limit failures
|
||||
n.SendBackupNotification(
|
||||
backupConfig,
|
||||
currentBackup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
currentBackup.FailMessage,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
errMsg := err.Error()
|
||||
|
||||
// Log detailed error information for debugging
|
||||
n.logger.Error("Backup execution failed",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", databaseID,
|
||||
"databaseType", database.Type,
|
||||
"storageId", storage.ID,
|
||||
"storageType", storage.Type,
|
||||
"error", err,
|
||||
"errorMessage", errMsg,
|
||||
)
|
||||
|
||||
// Check if backup was cancelled (not due to shutdown)
|
||||
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
|
||||
strings.Contains(errMsg, "context canceled") ||
|
||||
errors.Is(err, context.Canceled)
|
||||
isShutdown := strings.Contains(errMsg, "shutdown")
|
||||
|
||||
if isCancelled && !isShutdown {
|
||||
n.logger.Warn("Backup was cancelled by user or system",
|
||||
"backupId", backup.ID,
|
||||
"isCancelled", isCancelled,
|
||||
"isShutdown", isShutdown,
|
||||
)
|
||||
|
||||
backup.Status = backups_core.BackupStatusCanceled
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save cancelled backup", "error", err)
|
||||
}
|
||||
|
||||
// Delete partial backup from storage
|
||||
storage, storageErr := n.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr == nil {
|
||||
if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.FileName); deleteErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to delete partial backup file",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
deleteErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.FailMessage = &errMsg
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if updateErr := n.databaseService.SetBackupError(databaseID, errMsg); updateErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup", "error", err)
|
||||
}
|
||||
|
||||
n.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&errMsg,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.Status = backups_core.BackupStatusCompleted
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Update backup with encryption metadata if provided
|
||||
if backupMetadata != nil {
|
||||
backupMetadata.BackupID = backup.ID
|
||||
|
||||
if err := backupMetadata.Validate(); err != nil {
|
||||
n.logger.Error("Failed to validate backup metadata", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backup.EncryptionSalt = backupMetadata.EncryptionSalt
|
||||
backup.EncryptionIV = backupMetadata.EncryptionIV
|
||||
backup.Encryption = backupMetadata.Encryption
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Save metadata file to storage
|
||||
if backupMetadata != nil {
|
||||
metadataJSON, err := json.Marshal(backupMetadata)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to marshal backup metadata to JSON",
|
||||
"backupId", backup.ID,
|
||||
"error", err,
|
||||
)
|
||||
} else {
|
||||
metadataReader := bytes.NewReader(metadataJSON)
|
||||
metadataFileName := backup.FileName + ".metadata"
|
||||
|
||||
if err := storage.SaveFile(
|
||||
context.Background(),
|
||||
n.fieldEncryptor,
|
||||
n.logger,
|
||||
metadataFileName,
|
||||
metadataReader,
|
||||
); err != nil {
|
||||
n.logger.Error("Failed to save backup metadata file to storage",
|
||||
"backupId", backup.ID,
|
||||
"fileName", metadataFileName,
|
||||
"error", err,
|
||||
)
|
||||
} else {
|
||||
n.logger.Info("Backup metadata file saved successfully",
|
||||
"backupId", backup.ID,
|
||||
"fileName", metadataFileName,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update database last backup time
|
||||
now := time.Now().UTC()
|
||||
if updateErr := n.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if backup.Status != backups_core.BackupStatusCompleted && !isCallNotifier {
|
||||
return
|
||||
}
|
||||
|
||||
n.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupSuccess,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func (n *BackuperNode) SendBackupNotification(
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
backup *backups_core.Backup,
|
||||
notificationType backups_config.BackupNotificationType,
|
||||
errorMessage *string,
|
||||
) {
|
||||
database, err := n.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
workspace, err := n.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, notifier := range database.Notifiers {
|
||||
if !slices.Contains(
|
||||
backupConfig.SendNotificationsOn,
|
||||
notificationType,
|
||||
) {
|
||||
continue
|
||||
}
|
||||
|
||||
title := ""
|
||||
switch notificationType {
|
||||
case backups_config.NotificationBackupFailed:
|
||||
title = fmt.Sprintf(
|
||||
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
case backups_config.NotificationBackupSuccess:
|
||||
title = fmt.Sprintf(
|
||||
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
}
|
||||
|
||||
message := ""
|
||||
if errorMessage != nil {
|
||||
message = *errorMessage
|
||||
} else {
|
||||
// Format size conditionally
|
||||
var sizeStr string
|
||||
if backup.BackupSizeMb < 1024 {
|
||||
sizeStr = fmt.Sprintf("%.2f MB", backup.BackupSizeMb)
|
||||
} else {
|
||||
sizeGB := backup.BackupSizeMb / 1024
|
||||
sizeStr = fmt.Sprintf("%.2f GB", sizeGB)
|
||||
}
|
||||
|
||||
// Format duration as "0m 0s 0ms"
|
||||
totalMs := backup.BackupDurationMs
|
||||
minutes := totalMs / (1000 * 60)
|
||||
seconds := (totalMs % (1000 * 60)) / 1000
|
||||
durationStr := fmt.Sprintf("%dm %ds", minutes, seconds)
|
||||
|
||||
message = fmt.Sprintf(
|
||||
"Backup completed successfully in %s.\nCompressed backup size: %s",
|
||||
durationStr,
|
||||
sizeStr,
|
||||
)
|
||||
}
|
||||
|
||||
n.notificationSender.SendNotification(
|
||||
¬ifier,
|
||||
title,
|
||||
message,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *BackuperNode) sendHeartbeat(backupNode *BackupNode) {
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
|
||||
n.logger.Error("Failed to send heartbeat", "error", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,273 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
)
|
||||
|
||||
func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
t.Run("BackupFailed_FailNotificationSent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.notificationSender = mockNotificationSender
|
||||
backuperNode.createBackupUseCase = &CreateFailedBackupUsecase{}
|
||||
|
||||
// Create a backup record directly that will be looked up by MakeBackup
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Set up expectations
|
||||
mockNotificationSender.On("SendNotification",
|
||||
mock.Anything,
|
||||
mock.MatchedBy(func(title string) bool {
|
||||
return strings.Contains(title, "❌ Backup failed")
|
||||
}),
|
||||
mock.MatchedBy(func(message string) bool {
|
||||
return strings.Contains(message, "backup failed")
|
||||
}),
|
||||
).Once()
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, true)
|
||||
|
||||
// Verify all expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("BackupSuccess_SuccessNotificationSent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.notificationSender = mockNotificationSender
|
||||
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
|
||||
|
||||
// Create a backup record directly that will be looked up by MakeBackup
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Set up expectations
|
||||
mockNotificationSender.On("SendNotification",
|
||||
mock.Anything,
|
||||
mock.MatchedBy(func(title string) bool {
|
||||
return strings.Contains(title, "✅ Backup completed")
|
||||
}),
|
||||
mock.MatchedBy(func(message string) bool {
|
||||
return strings.Contains(message, "Backup completed successfully")
|
||||
}),
|
||||
).Once()
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, true)
|
||||
|
||||
// Verify all expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("BackupSuccess_VerifyNotificationContent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.notificationSender = mockNotificationSender
|
||||
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
|
||||
|
||||
// Create a backup record directly that will be looked up by MakeBackup
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// capture arguments
|
||||
var capturedNotifier *notifiers.Notifier
|
||||
var capturedTitle string
|
||||
var capturedMessage string
|
||||
|
||||
mockNotificationSender.On("SendNotification",
|
||||
mock.Anything,
|
||||
mock.AnythingOfType("string"),
|
||||
mock.AnythingOfType("string"),
|
||||
).Run(func(args mock.Arguments) {
|
||||
capturedNotifier = args.Get(0).(*notifiers.Notifier)
|
||||
capturedTitle = args.Get(1).(string)
|
||||
capturedMessage = args.Get(2).(string)
|
||||
}).Once()
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, true)
|
||||
|
||||
// Verify expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
|
||||
// Additional detailed assertions
|
||||
assert.Contains(t, capturedTitle, "✅ Backup completed")
|
||||
assert.Contains(t, capturedTitle, database.Name)
|
||||
assert.Contains(t, capturedMessage, "Backup completed successfully")
|
||||
assert.Contains(t, capturedMessage, "10.00 MB")
|
||||
assert.Equal(t, notifier.ID, capturedNotifier.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_BackupSizeLimits(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
t.Run("UnlimitedSize_MaxBackupSizeMBIsZero_BackupCompletes", func(t *testing.T) {
|
||||
// Enable backups with unlimited size (0)
|
||||
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backupConfig.MaxBackupSizeMB = 0 // unlimited
|
||||
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.createBackupUseCase = &CreateLargeBackupUsecase{}
|
||||
|
||||
// Create a backup record
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, false)
|
||||
|
||||
// Verify backup completed successfully even with large size
|
||||
updatedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
|
||||
assert.Equal(t, float64(10000), updatedBackup.BackupSizeMb)
|
||||
assert.Nil(t, updatedBackup.FailMessage)
|
||||
})
|
||||
|
||||
t.Run("SizeExceeded_BackupFailedWithIsSkipRetry", func(t *testing.T) {
|
||||
// Enable backups with 5 MB limit
|
||||
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backupConfig.MaxBackupSizeMB = 5
|
||||
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.createBackupUseCase = &CreateProgressiveBackupUsecase{}
|
||||
|
||||
// Create a backup record
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, false)
|
||||
|
||||
// Verify backup was marked as failed with IsSkipRetry=true
|
||||
updatedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusFailed, updatedBackup.Status)
|
||||
assert.True(t, updatedBackup.IsSkipRetry)
|
||||
assert.NotNil(t, updatedBackup.FailMessage)
|
||||
assert.Contains(t, *updatedBackup.FailMessage, "exceeded maximum allowed size")
|
||||
assert.Contains(t, *updatedBackup.FailMessage, "10.00 MB")
|
||||
assert.Contains(t, *updatedBackup.FailMessage, "5 MB")
|
||||
assert.Greater(t, updatedBackup.BackupSizeMb, float64(5))
|
||||
})
|
||||
|
||||
t.Run("SizeWithinLimit_BackupCompletes", func(t *testing.T) {
|
||||
// Enable backups with 100 MB limit
|
||||
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backupConfig.MaxBackupSizeMB = 100
|
||||
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.createBackupUseCase = &CreateMediumBackupUsecase{}
|
||||
|
||||
// Create a backup record
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, false)
|
||||
|
||||
// Verify backup completed successfully
|
||||
updatedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
|
||||
assert.Equal(t, float64(50), updatedBackup.BackupSizeMb)
|
||||
assert.Nil(t, updatedBackup.FailMessage)
|
||||
})
|
||||
}
|
||||
520
backend/internal/features/backups/backups/backuping/cleaner.go
Normal file
520
backend/internal/features/backups/backups/backuping/cleaner.go
Normal file
@@ -0,0 +1,520 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/storages"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
|
||||
const (
|
||||
cleanerTickerInterval = 1 * time.Minute
|
||||
recentBackupGracePeriod = 60 * time.Minute
|
||||
)
|
||||
|
||||
type BackupCleaner struct {
|
||||
backupRepository *backups_core.BackupRepository
|
||||
storageService *storages.StorageService
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
logger *slog.Logger
|
||||
backupRemoveListeners []backups_core.BackupRemoveListener
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) Run(ctx context.Context) {
|
||||
wasAlreadyRun := c.hasRun.Load()
|
||||
|
||||
c.runOnce.Do(func() {
|
||||
c.hasRun.Store(true)
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(cleanerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := c.cleanByRetentionPolicy(); err != nil {
|
||||
c.logger.Error("Failed to clean backups by retention policy", "error", err)
|
||||
}
|
||||
|
||||
if err := c.cleanExceededBackups(); err != nil {
|
||||
c.logger.Error("Failed to clean exceeded backups", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", c))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) DeleteBackup(backup *backups_core.Backup) error {
|
||||
for _, listener := range c.backupRemoveListeners {
|
||||
if err := listener.OnBeforeBackupRemove(backup); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
storage, err := c.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := storage.DeleteFile(c.fieldEncryptor, backup.FileName); err != nil {
|
||||
// we do not return error here, because sometimes clean up performed
|
||||
// before unavailable storage removal or change - therefore we should
|
||||
// proceed even in case of error. It's possible that some S3 or
|
||||
// storage is not available yet, it should not block us
|
||||
c.logger.Error("Failed to delete backup file", "error", err)
|
||||
}
|
||||
|
||||
metadataFileName := backup.FileName + ".metadata"
|
||||
if err := storage.DeleteFile(c.fieldEncryptor, metadataFileName); err != nil {
|
||||
c.logger.Error("Failed to delete backup metadata file", "error", err)
|
||||
}
|
||||
|
||||
return c.backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
|
||||
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByRetentionPolicy() error {
|
||||
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
var cleanErr error
|
||||
|
||||
switch backupConfig.RetentionPolicyType {
|
||||
case backups_config.RetentionPolicyTypeCount:
|
||||
cleanErr = c.cleanByCount(backupConfig)
|
||||
case backups_config.RetentionPolicyTypeGFS:
|
||||
cleanErr = c.cleanByGFS(backupConfig)
|
||||
default:
|
||||
cleanErr = c.cleanByTimePeriod(backupConfig)
|
||||
}
|
||||
|
||||
if cleanErr != nil {
|
||||
c.logger.Error(
|
||||
"Failed to clean backups by retention policy",
|
||||
"databaseId", backupConfig.DatabaseID,
|
||||
"policy", backupConfig.RetentionPolicyType,
|
||||
"error", cleanErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanExceededBackups() error {
|
||||
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
if backupConfig.MaxBackupsTotalSizeMB <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.cleanExceededBackupsForDatabase(
|
||||
backupConfig.DatabaseID,
|
||||
backupConfig.MaxBackupsTotalSizeMB,
|
||||
); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to clean exceeded backups for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByTimePeriod(backupConfig *backups_config.BackupConfig) error {
|
||||
if backupConfig.RetentionTimePeriod == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if backupConfig.RetentionTimePeriod == period.PeriodForever {
|
||||
return nil
|
||||
}
|
||||
|
||||
storeDuration := backupConfig.RetentionTimePeriod.ToDuration()
|
||||
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
|
||||
|
||||
oldBackups, err := c.backupRepository.FindBackupsBeforeDate(
|
||||
backupConfig.DatabaseID,
|
||||
dateBeforeBackupsShouldBeDeleted,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to find old backups for database %s: %w",
|
||||
backupConfig.DatabaseID,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
for _, backup := range oldBackups {
|
||||
if isRecentBackup(backup) {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted old backup",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByCount(backupConfig *backups_config.BackupConfig) error {
|
||||
if backupConfig.RetentionCount <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
completedBackups, err := c.backupRepository.FindByDatabaseIdAndStatus(
|
||||
backupConfig.DatabaseID,
|
||||
backups_core.BackupStatusCompleted,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to find completed backups for database %s: %w",
|
||||
backupConfig.DatabaseID,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
// completedBackups are ordered newest first; delete everything beyond position RetentionCount
|
||||
if len(completedBackups) <= backupConfig.RetentionCount {
|
||||
return nil
|
||||
}
|
||||
|
||||
toDelete := completedBackups[backupConfig.RetentionCount:]
|
||||
for _, backup := range toDelete {
|
||||
if isRecentBackup(backup) {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete backup by count policy",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted backup by count policy",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", backupConfig.DatabaseID,
|
||||
"retentionCount", backupConfig.RetentionCount,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanByGFS(backupConfig *backups_config.BackupConfig) error {
|
||||
if backupConfig.RetentionGfsHours <= 0 && backupConfig.RetentionGfsDays <= 0 &&
|
||||
backupConfig.RetentionGfsWeeks <= 0 && backupConfig.RetentionGfsMonths <= 0 &&
|
||||
backupConfig.RetentionGfsYears <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
completedBackups, err := c.backupRepository.FindByDatabaseIdAndStatus(
|
||||
backupConfig.DatabaseID,
|
||||
backups_core.BackupStatusCompleted,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to find completed backups for database %s: %w",
|
||||
backupConfig.DatabaseID,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
keepSet := buildGFSKeepSet(
|
||||
completedBackups,
|
||||
backupConfig.RetentionGfsHours,
|
||||
backupConfig.RetentionGfsDays,
|
||||
backupConfig.RetentionGfsWeeks,
|
||||
backupConfig.RetentionGfsMonths,
|
||||
backupConfig.RetentionGfsYears,
|
||||
)
|
||||
|
||||
for _, backup := range completedBackups {
|
||||
if keepSet[backup.ID] {
|
||||
continue
|
||||
}
|
||||
|
||||
if isRecentBackup(backup) {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete backup by GFS policy",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted backup by GFS policy",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanExceededBackupsForDatabase(
|
||||
databaseID uuid.UUID,
|
||||
limitperDbMB int64,
|
||||
) error {
|
||||
for {
|
||||
backupsTotalSizeMB, err := c.backupRepository.GetTotalSizeByDatabase(databaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if backupsTotalSizeMB <= float64(limitperDbMB) {
|
||||
break
|
||||
}
|
||||
|
||||
oldestBackups, err := c.backupRepository.FindOldestByDatabaseExcludingInProgress(
|
||||
databaseID,
|
||||
1,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(oldestBackups) == 0 {
|
||||
c.logger.Warn(
|
||||
"No backups to delete but still over limit",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"totalSizeMB",
|
||||
backupsTotalSizeMB,
|
||||
"limitMB",
|
||||
limitperDbMB,
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
backup := oldestBackups[0]
|
||||
if isRecentBackup(backup) {
|
||||
c.logger.Warn(
|
||||
"Oldest backup is too recent to delete, stopping size cleanup",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"totalSizeMB",
|
||||
backupsTotalSizeMB,
|
||||
"limitMB",
|
||||
limitperDbMB,
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete exceeded backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted exceeded backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"backupSizeMB",
|
||||
backup.BackupSizeMb,
|
||||
"totalSizeMB",
|
||||
backupsTotalSizeMB,
|
||||
"limitMB",
|
||||
limitperDbMB,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isRecentBackup(backup *backups_core.Backup) bool {
|
||||
return time.Since(backup.CreatedAt) < recentBackupGracePeriod
|
||||
}
|
||||
|
||||
// buildGFSKeepSet determines which backups to retain under the GFS rotation scheme.
|
||||
// Backups must be sorted newest-first. A backup can fill multiple slots simultaneously
|
||||
// (e.g. the newest backup of a year also fills the monthly, weekly, daily, and hourly slot).
|
||||
func buildGFSKeepSet(
|
||||
backups []*backups_core.Backup,
|
||||
hours, days, weeks, months, years int,
|
||||
) map[uuid.UUID]bool {
|
||||
keep := make(map[uuid.UUID]bool)
|
||||
|
||||
if len(backups) == 0 {
|
||||
return keep
|
||||
}
|
||||
|
||||
hoursSeen := make(map[string]bool)
|
||||
daysSeen := make(map[string]bool)
|
||||
weeksSeen := make(map[string]bool)
|
||||
monthsSeen := make(map[string]bool)
|
||||
yearsSeen := make(map[string]bool)
|
||||
|
||||
hoursKept, daysKept, weeksKept, monthsKept, yearsKept := 0, 0, 0, 0, 0
|
||||
|
||||
// Compute per-level time-window cutoffs so higher-frequency slots
|
||||
// cannot absorb backups that belong to lower-frequency levels.
|
||||
ref := backups[0].CreatedAt
|
||||
|
||||
rawHourlyCutoff := ref.Add(-time.Duration(hours) * time.Hour)
|
||||
rawDailyCutoff := ref.Add(-time.Duration(days) * 24 * time.Hour)
|
||||
rawWeeklyCutoff := ref.Add(-time.Duration(weeks) * 7 * 24 * time.Hour)
|
||||
rawMonthlyCutoff := ref.AddDate(0, -months, 0)
|
||||
rawYearlyCutoff := ref.AddDate(-years, 0, 0)
|
||||
|
||||
// Hierarchical capping: each level's window cannot extend further back
|
||||
// than the nearest active lower-frequency level's window.
|
||||
yearlyCutoff := rawYearlyCutoff
|
||||
|
||||
monthlyCutoff := rawMonthlyCutoff
|
||||
if years > 0 {
|
||||
monthlyCutoff = laterOf(monthlyCutoff, yearlyCutoff)
|
||||
}
|
||||
|
||||
weeklyCutoff := rawWeeklyCutoff
|
||||
if months > 0 {
|
||||
weeklyCutoff = laterOf(weeklyCutoff, monthlyCutoff)
|
||||
} else if years > 0 {
|
||||
weeklyCutoff = laterOf(weeklyCutoff, yearlyCutoff)
|
||||
}
|
||||
|
||||
dailyCutoff := rawDailyCutoff
|
||||
switch {
|
||||
case weeks > 0:
|
||||
dailyCutoff = laterOf(dailyCutoff, weeklyCutoff)
|
||||
case months > 0:
|
||||
dailyCutoff = laterOf(dailyCutoff, monthlyCutoff)
|
||||
case years > 0:
|
||||
dailyCutoff = laterOf(dailyCutoff, yearlyCutoff)
|
||||
}
|
||||
|
||||
hourlyCutoff := rawHourlyCutoff
|
||||
switch {
|
||||
case days > 0:
|
||||
hourlyCutoff = laterOf(hourlyCutoff, dailyCutoff)
|
||||
case weeks > 0:
|
||||
hourlyCutoff = laterOf(hourlyCutoff, weeklyCutoff)
|
||||
case months > 0:
|
||||
hourlyCutoff = laterOf(hourlyCutoff, monthlyCutoff)
|
||||
case years > 0:
|
||||
hourlyCutoff = laterOf(hourlyCutoff, yearlyCutoff)
|
||||
}
|
||||
|
||||
for _, backup := range backups {
|
||||
t := backup.CreatedAt
|
||||
|
||||
hourKey := t.Format("2006-01-02-15")
|
||||
dayKey := t.Format("2006-01-02")
|
||||
weekYear, week := t.ISOWeek()
|
||||
weekKey := fmt.Sprintf("%d-%02d", weekYear, week)
|
||||
monthKey := t.Format("2006-01")
|
||||
yearKey := t.Format("2006")
|
||||
|
||||
if hours > 0 && hoursKept < hours && !hoursSeen[hourKey] && t.After(hourlyCutoff) {
|
||||
keep[backup.ID] = true
|
||||
hoursSeen[hourKey] = true
|
||||
hoursKept++
|
||||
}
|
||||
|
||||
if days > 0 && daysKept < days && !daysSeen[dayKey] && t.After(dailyCutoff) {
|
||||
keep[backup.ID] = true
|
||||
daysSeen[dayKey] = true
|
||||
daysKept++
|
||||
}
|
||||
|
||||
if weeks > 0 && weeksKept < weeks && !weeksSeen[weekKey] && t.After(weeklyCutoff) {
|
||||
keep[backup.ID] = true
|
||||
weeksSeen[weekKey] = true
|
||||
weeksKept++
|
||||
}
|
||||
|
||||
if months > 0 && monthsKept < months && !monthsSeen[monthKey] && t.After(monthlyCutoff) {
|
||||
keep[backup.ID] = true
|
||||
monthsSeen[monthKey] = true
|
||||
monthsKept++
|
||||
}
|
||||
|
||||
if years > 0 && yearsKept < years && !yearsSeen[yearKey] && t.After(yearlyCutoff) {
|
||||
keep[backup.ID] = true
|
||||
yearsSeen[yearKey] = true
|
||||
yearsKept++
|
||||
}
|
||||
}
|
||||
|
||||
return keep
|
||||
}
|
||||
|
||||
func laterOf(a, b time.Time) time.Time {
|
||||
if a.After(b) {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
1020
backend/internal/features/backups/backups/backuping/cleaner_test.go
Normal file
1020
backend/internal/features/backups/backups/backuping/cleaner_test.go
Normal file
File diff suppressed because it is too large
Load Diff
98
backend/internal/features/backups/backups/backuping/di.go
Normal file
98
backend/internal/features/backups/backups/backuping/di.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var backupRepository = &backups_core.BackupRepository{}
|
||||
|
||||
var taskCancelManager = tasks_cancellation.GetTaskCancelManager()
|
||||
|
||||
var backupCleaner = &BackupCleaner{
|
||||
backupRepository,
|
||||
storages.GetStorageService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
logger.GetLogger(),
|
||||
[]backups_core.BackupRemoveListener{},
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
var backupNodesRegistry = &BackupNodesRegistry{
|
||||
cache_utils.GetValkeyClient(),
|
||||
logger.GetLogger(),
|
||||
cache_utils.DefaultCacheTimeout,
|
||||
cache_utils.NewPubSubManager(),
|
||||
cache_utils.NewPubSubManager(),
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
func getNodeID() uuid.UUID {
|
||||
return uuid.New()
|
||||
}
|
||||
|
||||
var backuperNode = &BackuperNode{
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
logger.GetLogger(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
getNodeID(),
|
||||
time.Time{},
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
var backupsScheduler = &BackupsScheduler{
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
databases.GetDatabaseService(),
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode,
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
func GetBackupsScheduler() *BackupsScheduler {
|
||||
return backupsScheduler
|
||||
}
|
||||
|
||||
func GetBackuperNode() *BackuperNode {
|
||||
return backuperNode
|
||||
}
|
||||
|
||||
func GetBackupNodesRegistry() *BackupNodesRegistry {
|
||||
return backupNodesRegistry
|
||||
}
|
||||
|
||||
func GetBackupCleaner() *BackupCleaner {
|
||||
return backupCleaner
|
||||
}
|
||||
34
backend/internal/features/backups/backups/backuping/dto.go
Normal file
34
backend/internal/features/backups/backups/backuping/dto.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackupToNodeRelation struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
BackupsIDs []uuid.UUID `json:"backupsIds"`
|
||||
}
|
||||
|
||||
type BackupNode struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ThroughputMBs int `json:"throughputMBs"`
|
||||
LastHeartbeat time.Time `json:"lastHeartbeat"`
|
||||
}
|
||||
|
||||
type BackupNodeStats struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ActiveBackups int `json:"activeBackups"`
|
||||
}
|
||||
|
||||
type BackupSubmitMessage struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
BackupID uuid.UUID `json:"backupId"`
|
||||
IsCallNotifier bool `json:"isCallNotifier"`
|
||||
}
|
||||
|
||||
type BackupCompletionMessage struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
BackupID uuid.UUID `json:"backupId"`
|
||||
}
|
||||
195
backend/internal/features/backups/backups/backuping/mocks.go
Normal file
195
backend/internal/features/backups/backups/backuping/mocks.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
)
|
||||
|
||||
type MockNotificationSender struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockNotificationSender) SendNotification(
|
||||
notifier *notifiers.Notifier,
|
||||
title string,
|
||||
message string,
|
||||
) {
|
||||
m.Called(notifier, title, message)
|
||||
}
|
||||
|
||||
type CreateFailedBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateFailedBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(10)
|
||||
return nil, errors.New("backup failed")
|
||||
}
|
||||
|
||||
type CreateSuccessBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateSuccessBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(10)
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateLargeBackupUsecase simulates a large backup (10000 MB)
|
||||
type CreateLargeBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateLargeBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(10000)
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateProgressiveBackupUsecase simulates progressive size updates that exceed limit
|
||||
type CreateProgressiveBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateProgressiveBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
// Simulate progressive backup that grows beyond limit
|
||||
backupProgressListener(1)
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
backupProgressListener(3)
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
backupProgressListener(5)
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
backupProgressListener(10) // This exceeds the 5 MB limit
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
// Should not reach here due to cancellation
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateMediumBackupUsecase simulates a 50 MB backup
|
||||
type CreateMediumBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateMediumBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(50)
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MockTrackingBackupUsecase tracks backup use case calls for testing parallel execution
|
||||
type MockTrackingBackupUsecase struct {
|
||||
callCount atomic.Int32
|
||||
calledBackupIDs chan uuid.UUID
|
||||
}
|
||||
|
||||
func NewMockTrackingBackupUsecase() *MockTrackingBackupUsecase {
|
||||
return &MockTrackingBackupUsecase{
|
||||
calledBackupIDs: make(chan uuid.UUID, 10),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockTrackingBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
m.callCount.Add(1)
|
||||
|
||||
// Send backup ID to channel (non-blocking)
|
||||
select {
|
||||
case m.calledBackupIDs <- backup.ID:
|
||||
default:
|
||||
}
|
||||
|
||||
// Simulate backup work
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
backupProgressListener(10)
|
||||
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockTrackingBackupUsecase) GetCallCount() int32 {
|
||||
return m.callCount.Load()
|
||||
}
|
||||
|
||||
func (m *MockTrackingBackupUsecase) GetCalledBackupIDs() []uuid.UUID {
|
||||
ids := []uuid.UUID{}
|
||||
for {
|
||||
select {
|
||||
case id := <-m.calledBackupIDs:
|
||||
ids = append(ids, id)
|
||||
default:
|
||||
return ids
|
||||
}
|
||||
}
|
||||
}
|
||||
633
backend/internal/features/backups/backups/backuping/registry.go
Normal file
633
backend/internal/features/backups/backups/backuping/registry.go
Normal file
@@ -0,0 +1,633 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/valkey-io/valkey-go"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
)
|
||||
|
||||
const (
|
||||
nodeInfoKeyPrefix = "backup:node:"
|
||||
nodeInfoKeySuffix = ":info"
|
||||
nodeActiveBackupsPrefix = "backup:node:"
|
||||
nodeActiveBackupsSuffix = ":active_backups"
|
||||
backupSubmitChannel = "backup:submit"
|
||||
backupCompletionChannel = "backup:completion"
|
||||
|
||||
deadNodeThreshold = 2 * time.Minute
|
||||
cleanupTickerInterval = 1 * time.Second
|
||||
)
|
||||
|
||||
// BackupNodesRegistry helps to sync backups scheduler and backup nodes.
|
||||
//
|
||||
// Features:
|
||||
// - Track node availability and load level
|
||||
// - Assign from scheduler to node backups needed to be processed
|
||||
// - Notify scheduler from node about backup completion
|
||||
//
|
||||
// Important things to remember:
|
||||
// - Nodes without heartbeat for more than 2 minutes are not included
|
||||
// in available nodes list and stats
|
||||
//
|
||||
// Cleanup dead nodes performed on 2 levels:
|
||||
// - List and stats functions do not return dead nodes
|
||||
// - Periodically dead nodes are cleaned up in cache (to not
|
||||
// accumulate too many dead nodes in cache)
|
||||
type BackupNodesRegistry struct {
|
||||
client valkey.Client
|
||||
logger *slog.Logger
|
||||
timeout time.Duration
|
||||
pubsubBackups *cache_utils.PubSubManager
|
||||
pubsubCompletions *cache_utils.PubSubManager
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) Run(ctx context.Context) {
|
||||
wasAlreadyRun := r.hasRun.Load()
|
||||
|
||||
r.runOnce.Do(func() {
|
||||
r.hasRun.Store(true)
|
||||
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(cleanupTickerInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", r))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return nil, fmt.Errorf("failed to scan node keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return []BackupNode{}, nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var nodes []BackupNode
|
||||
|
||||
for key, data := range keyDataMap {
|
||||
// Skip if the key doesn't exist (data is empty)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node BackupNode
|
||||
if err := json.Unmarshal(data, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip nodes with zero/uninitialized heartbeat
|
||||
if node.LastHeartbeat.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.LastHeartbeat.Before(threshold) {
|
||||
continue
|
||||
}
|
||||
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeActiveBackupsPrefix + "*" + nodeActiveBackupsSuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return nil, fmt.Errorf("failed to scan active backups keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return []BackupNodeStats{}, nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get active backups keys: %w", err)
|
||||
}
|
||||
|
||||
var nodeInfoKeys []string
|
||||
nodeIDToStatsKey := make(map[string]string)
|
||||
for key := range keyDataMap {
|
||||
nodeID := r.extractNodeIDFromKey(key, nodeActiveBackupsPrefix, nodeActiveBackupsSuffix)
|
||||
nodeIDStr := nodeID.String()
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeIDStr, nodeInfoKeySuffix)
|
||||
nodeInfoKeys = append(nodeInfoKeys, infoKey)
|
||||
nodeIDToStatsKey[infoKey] = key
|
||||
}
|
||||
|
||||
nodeInfoMap, err := r.pipelineGetKeys(nodeInfoKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get node info keys: %w", err)
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var stats []BackupNodeStats
|
||||
for infoKey, nodeData := range nodeInfoMap {
|
||||
// Skip if the info key doesn't exist (nodeData is empty)
|
||||
if len(nodeData) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node BackupNode
|
||||
if err := json.Unmarshal(nodeData, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data", "key", infoKey, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip nodes with zero/uninitialized heartbeat
|
||||
if node.LastHeartbeat.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.LastHeartbeat.Before(threshold) {
|
||||
continue
|
||||
}
|
||||
|
||||
statsKey := nodeIDToStatsKey[infoKey]
|
||||
tasksData := keyDataMap[statsKey]
|
||||
count, err := r.parseIntFromBytes(tasksData)
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse active backups count", "key", statsKey, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
stat := BackupNodeStats{
|
||||
ID: node.ID,
|
||||
ActiveBackups: int(count),
|
||||
}
|
||||
stats = append(stats, stat)
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID uuid.UUID) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID.String(), nodeActiveBackupsSuffix)
|
||||
result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to increment backups in progress for node %s: %w",
|
||||
nodeID,
|
||||
result.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID uuid.UUID) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID.String(), nodeActiveBackupsSuffix)
|
||||
result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to decrement backups in progress for node %s: %w",
|
||||
nodeID,
|
||||
result.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
newValue, err := result.AsInt64()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse decremented value for node %s: %w", nodeID, err)
|
||||
}
|
||||
|
||||
if newValue < 0 {
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
|
||||
setCancel()
|
||||
r.logger.Warn("Active backups counter went below 0, reset to 0", "nodeID", nodeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) HearthbeatNodeInRegistry(now time.Time, backupNode BackupNode) error {
|
||||
if now.IsZero() {
|
||||
return fmt.Errorf("cannot register node with zero heartbeat timestamp")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
backupNode.LastHeartbeat = now
|
||||
|
||||
data, err := json.Marshal(backupNode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal backup node: %w", err)
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Set().Key(key).Value(string(data)).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to register node %s: %w", backupNode.ID, result.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) UnregisterNodeFromRegistry(backupNode BackupNode) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
|
||||
counterKey := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveBackupsPrefix,
|
||||
backupNode.ID.String(),
|
||||
nodeActiveBackupsSuffix,
|
||||
)
|
||||
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Del().Key(infoKey, counterKey).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to unregister node %s: %w", backupNode.ID, result.Error())
|
||||
}
|
||||
|
||||
r.logger.Info("Unregistered node from registry", "nodeID", backupNode.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) AssignBackupToNode(
|
||||
targetNodeID uuid.UUID,
|
||||
backupID uuid.UUID,
|
||||
isCallNotifier bool,
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
message := BackupSubmitMessage{
|
||||
NodeID: targetNodeID,
|
||||
BackupID: backupID,
|
||||
IsCallNotifier: isCallNotifier,
|
||||
}
|
||||
|
||||
messageJSON, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal backup submit message: %w", err)
|
||||
}
|
||||
|
||||
err = r.pubsubBackups.Publish(ctx, backupSubmitChannel, string(messageJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to publish backup submit message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment(
|
||||
nodeID uuid.UUID,
|
||||
handler func(backupID uuid.UUID, isCallNotifier bool),
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
wrappedHandler := func(message string) {
|
||||
var msg BackupSubmitMessage
|
||||
if err := json.Unmarshal([]byte(message), &msg); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal backup submit message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if msg.NodeID != nodeID {
|
||||
return
|
||||
}
|
||||
|
||||
handler(msg.BackupID, msg.IsCallNotifier)
|
||||
}
|
||||
|
||||
err := r.pubsubBackups.Subscribe(ctx, backupSubmitChannel, wrappedHandler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to backup submit channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Subscribed to backup submit channel", "nodeID", nodeID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) UnsubscribeNodeForBackupsAssignments() error {
|
||||
err := r.pubsubBackups.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unsubscribe from backup submit channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Unsubscribed from backup submit channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID, backupID uuid.UUID) error {
|
||||
ctx := context.Background()
|
||||
|
||||
message := BackupCompletionMessage{
|
||||
NodeID: nodeID,
|
||||
BackupID: backupID,
|
||||
}
|
||||
|
||||
messageJSON, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal backup completion message: %w", err)
|
||||
}
|
||||
|
||||
err = r.pubsubCompletions.Publish(ctx, backupCompletionChannel, string(messageJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to publish backup completion message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
|
||||
handler func(nodeID, backupID uuid.UUID),
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
wrappedHandler := func(message string) {
|
||||
var msg BackupCompletionMessage
|
||||
if err := json.Unmarshal([]byte(message), &msg); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal backup completion message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
handler(msg.NodeID, msg.BackupID)
|
||||
}
|
||||
|
||||
err := r.pubsubCompletions.Subscribe(ctx, backupCompletionChannel, wrappedHandler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to backup completion channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Subscribed to backup completion channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) UnsubscribeForBackupsCompletions() error {
|
||||
err := r.pubsubCompletions.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unsubscribe from backup completion channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Unsubscribed from backup completion channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
|
||||
nodeIDStr := strings.TrimPrefix(key, prefix)
|
||||
nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix)
|
||||
|
||||
nodeID, err := uuid.Parse(nodeIDStr)
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse node ID from key", "key", key, "error", err)
|
||||
return uuid.Nil
|
||||
}
|
||||
|
||||
return nodeID
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
|
||||
if len(keys) == 0 {
|
||||
return make(map[string][]byte), nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
commands := make([]valkey.Completed, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
commands = append(commands, r.client.B().Get().Key(key).Build())
|
||||
}
|
||||
|
||||
results := r.client.DoMulti(ctx, commands...)
|
||||
|
||||
keyDataMap := make(map[string][]byte, len(keys))
|
||||
for i, result := range results {
|
||||
if result.Error() != nil {
|
||||
r.logger.Warn("Failed to get key in pipeline", "key", keys[i], "error", result.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := result.AsBytes()
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse key data in pipeline", "key", keys[i], "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
keyDataMap[keys[i]] = data
|
||||
}
|
||||
|
||||
return keyDataMap, nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
|
||||
str := string(data)
|
||||
var count int64
|
||||
_, err := fmt.Sscanf(str, "%d", &count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to parse integer from bytes: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) cleanupDeadNodes() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to scan node keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to pipeline get node keys: %w", err)
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var deadNodeKeys []string
|
||||
|
||||
for key, data := range keyDataMap {
|
||||
// Skip if the key doesn't exist (data is empty)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node BackupNode
|
||||
if err := json.Unmarshal(data, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data during cleanup", "key", key, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip nodes with zero/uninitialized heartbeat
|
||||
if node.LastHeartbeat.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.LastHeartbeat.Before(threshold) {
|
||||
nodeID := node.ID.String()
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeID, nodeInfoKeySuffix)
|
||||
statsKey := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveBackupsPrefix,
|
||||
nodeID,
|
||||
nodeActiveBackupsSuffix,
|
||||
)
|
||||
|
||||
deadNodeKeys = append(deadNodeKeys, infoKey, statsKey)
|
||||
r.logger.Info(
|
||||
"Marking node for cleanup",
|
||||
"nodeID", nodeID,
|
||||
"lastHeartbeat", node.LastHeartbeat,
|
||||
"threshold", threshold,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if len(deadNodeKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
delCtx, delCancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer delCancel()
|
||||
|
||||
result := r.client.Do(
|
||||
delCtx,
|
||||
r.client.B().Del().Key(deadNodeKeys...).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to delete dead node keys: %w", result.Error())
|
||||
}
|
||||
|
||||
deletedCount, err := result.AsInt64()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse deleted count: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Cleaned up dead nodes", "deletedKeysCount", deletedCount)
|
||||
return nil
|
||||
}
|
||||
1134
backend/internal/features/backups/backups/backuping/registry_test.go
Normal file
1134
backend/internal/features/backups/backups/backuping/registry_test.go
Normal file
File diff suppressed because it is too large
Load Diff
587
backend/internal/features/backups/backups/backuping/scheduler.go
Normal file
587
backend/internal/features/backups/backups/backuping/scheduler.go
Normal file
@@ -0,0 +1,587 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
)
|
||||
|
||||
const (
|
||||
schedulerStartupDelay = 1 * time.Minute
|
||||
schedulerTickerInterval = 1 * time.Minute
|
||||
schedulerHealthcheckThreshold = 5 * time.Minute
|
||||
)
|
||||
|
||||
type BackupsScheduler struct {
|
||||
backupRepository *backups_core.BackupRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
taskCancelManager *task_cancellation.TaskCancelManager
|
||||
backupNodesRegistry *BackupNodesRegistry
|
||||
databaseService *databases.DatabaseService
|
||||
|
||||
lastBackupTime time.Time
|
||||
logger *slog.Logger
|
||||
|
||||
backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
|
||||
backuperNode *BackuperNode
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) Run(ctx context.Context) {
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
|
||||
if config.GetEnv().IsManyNodesMode {
|
||||
// wait other nodes to start
|
||||
time.Sleep(schedulerStartupDelay)
|
||||
}
|
||||
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to subscribe to backup completions", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
|
||||
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(schedulerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.checkDeadNodesAndFailBackups(); err != nil {
|
||||
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) IsSchedulerRunning() bool {
|
||||
// if last backup time is more than 5 minutes ago, return false
|
||||
return s.lastBackupTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) IsBackupNodesAvailable() bool {
|
||||
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get available nodes for health check", "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return len(nodes) > 0
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotifier bool) {
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if backupConfig.StorageID == nil {
|
||||
s.logger.Error("Backup config storage ID is nil", "databaseId", database.ID)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for existing in-progress backups
|
||||
inProgressBackups, err := s.backupRepository.FindByDatabaseIdAndStatus(
|
||||
database.ID,
|
||||
backups_core.BackupStatusInProgress,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to check for in-progress backups",
|
||||
"databaseId",
|
||||
database.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if len(inProgressBackups) > 0 {
|
||||
s.logger.Warn(
|
||||
"Backup already in progress for database, skipping new backup",
|
||||
"databaseId",
|
||||
database.ID,
|
||||
"existingBackupId",
|
||||
inProgressBackups[0].ID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
leastBusyNodeID, err := s.calculateLeastBusyNode()
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to calculate least busy node",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
backupID := uuid.New()
|
||||
timestamp := time.Now().UTC()
|
||||
|
||||
backup := &backups_core.Backup{
|
||||
ID: backupID,
|
||||
DatabaseID: backupConfig.DatabaseID,
|
||||
StorageID: *backupConfig.StorageID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 0,
|
||||
CreatedAt: timestamp,
|
||||
}
|
||||
|
||||
backup.GenerateFilename(database.Name)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to save backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.backupNodesRegistry.IncrementBackupsInProgress(*leastBusyNodeID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to increment backups in progress",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.backupNodesRegistry.AssignBackupToNode(*leastBusyNodeID, backup.ID, isCallNotifier); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to submit backup",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
if decrementErr := s.backupNodesRegistry.DecrementBackupsInProgress(*leastBusyNodeID); decrementErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress after submit failure",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"error",
|
||||
decrementErr,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if relation, exists := s.backupToNodeRelations[*leastBusyNodeID]; exists {
|
||||
relation.BackupsIDs = append(relation.BackupsIDs, backup.ID)
|
||||
s.backupToNodeRelations[*leastBusyNodeID] = relation
|
||||
} else {
|
||||
s.backupToNodeRelations[*leastBusyNodeID] = BackupToNodeRelation{
|
||||
*leastBusyNodeID,
|
||||
[]uuid.UUID{backup.ID},
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Successfully triggered scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
)
|
||||
}
|
||||
|
||||
// GetRemainedBackupTryCount returns the number of remaining backup tries for a given backup.
|
||||
// If the backup is not failed or the backup config does not allow retries, it returns 0.
|
||||
// If the backup is failed and the backup config allows retries, it returns the number of remaining tries.
|
||||
// If the backup is failed and the backup config does not allow retries, it returns 0.
|
||||
func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Backup) int {
|
||||
if lastBackup == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if lastBackup.Status != backups_core.BackupStatusFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
if lastBackup.IsSkipRetry {
|
||||
return 0
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
if !backupConfig.IsRetryIfFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
maxFailedTriesCount := backupConfig.MaxFailedTriesCount
|
||||
|
||||
lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
|
||||
lastBackup.DatabaseID,
|
||||
maxFailedTriesCount,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to find last backups by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
lastFailedBackups := make([]*backups_core.Backup, 0)
|
||||
|
||||
for _, backup := range lastBackups {
|
||||
if backup.Status == backups_core.BackupStatusFailed {
|
||||
lastFailedBackups = append(lastFailedBackups, backup)
|
||||
}
|
||||
}
|
||||
|
||||
return maxFailedTriesCount - len(lastFailedBackups)
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) runPendingBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
if backupConfig.BackupInterval == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get last backup for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
var lastBackupTime *time.Time
|
||||
if lastBackup != nil {
|
||||
lastBackupTime = &lastBackup.CreatedAt
|
||||
}
|
||||
|
||||
remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
|
||||
|
||||
if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
|
||||
remainedBackupTryCount > 0 {
|
||||
s.logger.Info(
|
||||
"Triggering scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"intervalType",
|
||||
backupConfig.BackupInterval.Interval,
|
||||
)
|
||||
|
||||
database, err := s.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get database by ID", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
s.StartBackup(database, remainedBackupTryCount == 1)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) failBackupsInProgress() error {
|
||||
backupsInProgress, err := s.backupRepository.FindByStatus(backups_core.BackupStatusInProgress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backup := range backupsInProgress {
|
||||
if err := s.taskCancelManager.CancelTask(backup.ID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to cancel backup via task cancel manager",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
failMessage := "Backup failed due to application restart"
|
||||
backup.FailMessage = &failMessage
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
s.backuperNode.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&failMessage,
|
||||
)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
|
||||
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get available nodes: %w", err)
|
||||
}
|
||||
|
||||
if len(nodes) == 0 {
|
||||
return nil, fmt.Errorf("no nodes available")
|
||||
}
|
||||
|
||||
stats, err := s.backupNodesRegistry.GetBackupNodesStats()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get backup nodes stats: %w", err)
|
||||
}
|
||||
|
||||
statsMap := make(map[uuid.UUID]int)
|
||||
for _, stat := range stats {
|
||||
statsMap[stat.ID] = stat.ActiveBackups
|
||||
}
|
||||
|
||||
var bestNode *BackupNode
|
||||
var bestScore float64 = -1
|
||||
|
||||
for i := range nodes {
|
||||
node := &nodes[i]
|
||||
|
||||
activeBackups := statsMap[node.ID]
|
||||
|
||||
var score float64
|
||||
if node.ThroughputMBs > 0 {
|
||||
score = float64(activeBackups) / float64(node.ThroughputMBs)
|
||||
} else {
|
||||
score = float64(activeBackups) * 1000
|
||||
}
|
||||
|
||||
if bestNode == nil || score < bestScore {
|
||||
bestNode = node
|
||||
bestScore = score
|
||||
}
|
||||
}
|
||||
|
||||
if bestNode == nil {
|
||||
return nil, fmt.Errorf("no suitable nodes available")
|
||||
}
|
||||
|
||||
return &bestNode.ID, nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) onBackupCompleted(nodeID, backupID uuid.UUID) {
|
||||
// Verify this task is actually a backup (registry contains multiple task types)
|
||||
_, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
// Not a backup task, ignore it
|
||||
return
|
||||
}
|
||||
|
||||
relation, exists := s.backupToNodeRelations[nodeID]
|
||||
if !exists {
|
||||
s.logger.Warn(
|
||||
"Received completion for unknown node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
newBackupIDs := make([]uuid.UUID, 0)
|
||||
found := false
|
||||
for _, id := range relation.BackupsIDs {
|
||||
if id == backupID {
|
||||
found = true
|
||||
continue
|
||||
}
|
||||
newBackupIDs = append(newBackupIDs, id)
|
||||
}
|
||||
|
||||
if !found {
|
||||
s.logger.Warn(
|
||||
"Backup not found in node's backup list",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if len(newBackupIDs) == 0 {
|
||||
delete(s.backupToNodeRelations, nodeID)
|
||||
} else {
|
||||
relation.BackupsIDs = newBackupIDs
|
||||
s.backupToNodeRelations[nodeID] = relation
|
||||
}
|
||||
|
||||
if err := s.backupNodesRegistry.DecrementBackupsInProgress(nodeID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
|
||||
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get available nodes: %w", err)
|
||||
}
|
||||
|
||||
aliveNodeIDs := make(map[uuid.UUID]bool)
|
||||
for _, node := range nodes {
|
||||
aliveNodeIDs[node.ID] = true
|
||||
}
|
||||
|
||||
for nodeID, relation := range s.backupToNodeRelations {
|
||||
if aliveNodeIDs[nodeID] {
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Warn(
|
||||
"Node is dead, failing its backups",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupCount",
|
||||
len(relation.BackupsIDs),
|
||||
)
|
||||
|
||||
for _, backupID := range relation.BackupsIDs {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find backup for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
failMessage := "Backup failed due to node unavailability"
|
||||
backup.FailMessage = &failMessage
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to save failed backup for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.backupNodesRegistry.DecrementBackupsInProgress(nodeID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Failed backup due to dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
)
|
||||
}
|
||||
|
||||
delete(s.backupToNodeRelations, nodeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
318
backend/internal/features/backups/backups/backuping/testing.go
Normal file
318
backend/internal/features/backups/backups/backuping/testing.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
func CreateTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
backups_config.GetBackupConfigController(),
|
||||
)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func CreateTestBackuperNode() *BackuperNode {
|
||||
return &BackuperNode{
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
createBackupUseCase: usecases.GetCreateBackupUsecase(),
|
||||
nodeID: uuid.New(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestBackuperNodeWithUseCase(useCase backups_core.CreateBackupUsecase) *BackuperNode {
|
||||
return &BackuperNode{
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
createBackupUseCase: useCase,
|
||||
nodeID: uuid.New(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestScheduler() *BackupsScheduler {
|
||||
return &BackupsScheduler{
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
taskCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
lastBackupTime: time.Now().UTC(),
|
||||
logger: logger.GetLogger(),
|
||||
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode: CreateTestBackuperNode(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForBackupCompletion waits for a new backup to be created and completed (or failed)
|
||||
// for the given database. It checks for backups with count greater than expectedInitialCount.
|
||||
func WaitForBackupCompletion(
|
||||
t *testing.T,
|
||||
databaseID uuid.UUID,
|
||||
expectedInitialCount int,
|
||||
timeout time.Duration,
|
||||
) {
|
||||
deadline := time.Now().UTC().Add(timeout)
|
||||
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
backups, err := backupRepository.FindByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
t.Logf("WaitForBackupCompletion: error finding backups: %v", err)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
t.Logf(
|
||||
"WaitForBackupCompletion: found %d backups (expected > %d)",
|
||||
len(backups),
|
||||
expectedInitialCount,
|
||||
)
|
||||
|
||||
if len(backups) > expectedInitialCount {
|
||||
// Check if the newest backup has completed or failed
|
||||
newestBackup := backups[0]
|
||||
t.Logf("WaitForBackupCompletion: newest backup status: %s", newestBackup.Status)
|
||||
|
||||
if newestBackup.Status == backups_core.BackupStatusCompleted ||
|
||||
newestBackup.Status == backups_core.BackupStatusFailed ||
|
||||
newestBackup.Status == backups_core.BackupStatusCanceled {
|
||||
t.Logf(
|
||||
"WaitForBackupCompletion: backup finished with status %s",
|
||||
newestBackup.Status,
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
|
||||
}
|
||||
|
||||
// StartBackuperNodeForTest starts a BackuperNode in a goroutine for testing.
|
||||
// The node registers itself in the registry and subscribes to backup assignments.
|
||||
// Returns a context cancel function that should be deferred to stop the node.
|
||||
func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
backuperNode.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Poll registry for node presence instead of fixed sleep
|
||||
deadline := time.Now().UTC().Add(5 * time.Second)
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
nodes, err := backupNodesRegistry.GetAvailableNodes()
|
||||
if err == nil {
|
||||
for _, node := range nodes {
|
||||
if node.ID == backuperNode.nodeID {
|
||||
t.Logf("BackuperNode registered in registry: %s", backuperNode.nodeID)
|
||||
|
||||
return func() {
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("BackuperNode stopped gracefully")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Log("BackuperNode stop timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Fatalf("BackuperNode failed to register in registry within timeout")
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartSchedulerForTest starts the BackupsScheduler in a goroutine for testing.
|
||||
// The scheduler subscribes to task completions and manages backup lifecycle.
|
||||
// Returns a context cancel function that should be deferred to stop the scheduler.
|
||||
func StartSchedulerForTest(t *testing.T, scheduler *BackupsScheduler) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
scheduler.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Give scheduler time to subscribe to completions
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
t.Log("BackupsScheduler started")
|
||||
|
||||
return func() {
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("BackupsScheduler stopped gracefully")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Log("BackupsScheduler stop timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StopBackuperNodeForTest stops the BackuperNode by canceling its context.
|
||||
// It waits for the node to unregister from the registry.
|
||||
func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNode *BackuperNode) {
|
||||
cancel()
|
||||
|
||||
// Wait for node to unregister from registry
|
||||
deadline := time.Now().UTC().Add(2 * time.Second)
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
nodes, err := backupNodesRegistry.GetAvailableNodes()
|
||||
if err == nil {
|
||||
found := false
|
||||
for _, node := range nodes {
|
||||
if node.ID == backuperNode.nodeID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Logf("BackuperNode unregistered from registry: %s", backuperNode.nodeID)
|
||||
return
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("BackuperNode stop completed for %s", backuperNode.nodeID)
|
||||
}
|
||||
|
||||
func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
|
||||
backupNode := BackupNode{
|
||||
ID: nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: lastHeartbeat,
|
||||
}
|
||||
|
||||
return backupNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
|
||||
}
|
||||
|
||||
func UpdateNodeHeartbeatDirectly(
|
||||
nodeID uuid.UUID,
|
||||
throughputMBs int,
|
||||
lastHeartbeat time.Time,
|
||||
) error {
|
||||
backupNode := BackupNode{
|
||||
ID: nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: lastHeartbeat,
|
||||
}
|
||||
|
||||
return backupNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
|
||||
}
|
||||
|
||||
func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
|
||||
nodes, err := backupNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
if node.ID == nodeID {
|
||||
return &node, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("node not found")
|
||||
}
|
||||
|
||||
// WaitForActiveTasksDecrease waits for the active task count to decrease below the initial count.
|
||||
// It polls the registry every 500ms until the count decreases or the timeout is reached.
|
||||
// Returns true if the count decreased, false if timeout was reached.
|
||||
func WaitForActiveTasksDecrease(
|
||||
t *testing.T,
|
||||
nodeID uuid.UUID,
|
||||
initialCount int,
|
||||
timeout time.Duration,
|
||||
) bool {
|
||||
deadline := time.Now().UTC().Add(timeout)
|
||||
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
stats, err := backupNodesRegistry.GetBackupNodesStats()
|
||||
if err != nil {
|
||||
t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, stat := range stats {
|
||||
if stat.ID == nodeID {
|
||||
t.Logf(
|
||||
"WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)",
|
||||
stat.ActiveBackups,
|
||||
initialCount,
|
||||
)
|
||||
if stat.ActiveBackups < initialCount {
|
||||
t.Logf(
|
||||
"WaitForActiveTasksDecrease: active tasks decreased from %d to %d",
|
||||
initialCount,
|
||||
stat.ActiveBackups,
|
||||
)
|
||||
return true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("WaitForActiveTasksDecrease: timeout waiting for active tasks to decrease")
|
||||
return false
|
||||
}
|
||||
38
backend/internal/features/backups/backups/common/dto.go
Normal file
38
backend/internal/features/backups/backups/common/dto.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
)
|
||||
|
||||
type BackupMetadata struct {
|
||||
BackupID uuid.UUID `json:"backupId"`
|
||||
EncryptionSalt *string `json:"encryptionSalt"`
|
||||
EncryptionIV *string `json:"encryptionIV"`
|
||||
Encryption backups_config.BackupEncryption `json:"encryption"`
|
||||
}
|
||||
|
||||
func (m *BackupMetadata) Validate() error {
|
||||
if m.BackupID == uuid.Nil {
|
||||
return errors.New("backup ID is required")
|
||||
}
|
||||
|
||||
if m.Encryption == "" {
|
||||
return errors.New("encryption is required")
|
||||
}
|
||||
|
||||
if m.Encryption == backups_config.BackupEncryptionEncrypted {
|
||||
if m.EncryptionSalt == nil {
|
||||
return errors.New("encryption salt is required when encryption is enabled")
|
||||
}
|
||||
|
||||
if m.EncryptionIV == nil {
|
||||
return errors.New("encryption IV is required when encryption is enabled")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -7,6 +7,10 @@ type CountingWriter struct {
|
||||
BytesWritten int64
|
||||
}
|
||||
|
||||
func NewCountingWriter(writer io.Writer) *CountingWriter {
|
||||
return &CountingWriter{Writer: writer}
|
||||
}
|
||||
|
||||
func (cw *CountingWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = cw.Writer.Write(p)
|
||||
cw.BytesWritten += int64(n)
|
||||
@@ -16,7 +20,3 @@ func (cw *CountingWriter) Write(p []byte) (n int, err error) {
|
||||
func (cw *CountingWriter) GetBytesWritten() int64 {
|
||||
return cw.BytesWritten
|
||||
}
|
||||
|
||||
func NewCountingWriter(writer io.Writer) *CountingWriter {
|
||||
return &CountingWriter{Writer: writer}
|
||||
}
|
||||
@@ -1,216 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackupController struct {
|
||||
backupService *BackupService
|
||||
}
|
||||
|
||||
func (c *BackupController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
router.GET("/backups", c.GetBackups)
|
||||
router.POST("/backups", c.MakeBackup)
|
||||
router.GET("/backups/:id/file", c.GetFile)
|
||||
router.DELETE("/backups/:id", c.DeleteBackup)
|
||||
router.POST("/backups/:id/cancel", c.CancelBackup)
|
||||
}
|
||||
|
||||
// GetBackups
|
||||
// @Summary Get backups for a database
|
||||
// @Description Get paginated backups for the specified database
|
||||
// @Tags backups
|
||||
// @Produce json
|
||||
// @Param database_id query string true "Database ID"
|
||||
// @Param limit query int false "Number of items per page" default(10)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Success 200 {object} GetBackupsResponse
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /backups [get]
|
||||
func (c *BackupController) GetBackups(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request GetBackupsRequest
|
||||
if err := ctx.ShouldBindQuery(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
databaseID, err := uuid.Parse(request.DatabaseID)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database_id"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.backupService.GetBackups(user, databaseID, request.Limit, request.Offset)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// MakeBackup
|
||||
// @Summary Create a backup
|
||||
// @Description Create a new backup for the specified database
|
||||
// @Tags backups
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body MakeBackupRequest true "Backup creation data"
|
||||
// @Success 200 {object} map[string]string
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /backups [post]
|
||||
func (c *BackupController) MakeBackup(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request MakeBackupRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.backupService.MakeBackupWithAuth(user, request.DatabaseID); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, gin.H{"message": "backup started successfully"})
|
||||
}
|
||||
|
||||
// DeleteBackup
|
||||
// @Summary Delete a backup
|
||||
// @Description Delete an existing backup
|
||||
// @Tags backups
|
||||
// @Param id path string true "Backup ID"
|
||||
// @Success 204
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /backups/{id} [delete]
|
||||
func (c *BackupController) DeleteBackup(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.backupService.DeleteBackup(user, id); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// CancelBackup
|
||||
// @Summary Cancel an in-progress backup
|
||||
// @Description Cancel a backup that is currently in progress
|
||||
// @Tags backups
|
||||
// @Param id path string true "Backup ID"
|
||||
// @Success 204
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /backups/{id}/cancel [post]
|
||||
func (c *BackupController) CancelBackup(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.backupService.CancelBackup(user, id); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// GetFile
|
||||
// @Summary Download a backup file
|
||||
// @Description Download the backup file for the specified backup
|
||||
// @Tags backups
|
||||
// @Param id path string true "Backup ID"
|
||||
// @Success 200 {file} file
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /backups/{id}/file [get]
|
||||
func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
|
||||
return
|
||||
}
|
||||
|
||||
fileReader, dbType, err := c.backupService.GetBackupFile(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := fileReader.Close(); err != nil {
|
||||
fmt.Printf("Error closing file reader: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
extension := ".dump"
|
||||
if dbType == databases.DatabaseTypeMysql || dbType == databases.DatabaseTypeMariadb {
|
||||
extension = ".sql.zst"
|
||||
}
|
||||
|
||||
ctx.Header("Content-Type", "application/octet-stream")
|
||||
ctx.Header(
|
||||
"Content-Disposition",
|
||||
fmt.Sprintf("attachment; filename=\"backup_%s%s\"", id.String(), extension),
|
||||
)
|
||||
|
||||
_, err = io.Copy(ctx.Writer, fileReader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "failed to stream file"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type MakeBackupRequest struct {
|
||||
DatabaseID uuid.UUID `json:"database_id" binding:"required"`
|
||||
}
|
||||
@@ -1,711 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"databasus-backend/internal/features/storages"
|
||||
local_storage "databasus-backend/internal/features/storages/models/local"
|
||||
users_dto "databasus-backend/internal/features/users/dto"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_models "databasus-backend/internal/features/workspaces/models"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func Test_GetBackups_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace viewer can get backups",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can get backups",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot get backups",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can get backups",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, _ := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups?database_id=%s", database.ID.String()),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
|
||||
if tt.expectSuccess {
|
||||
var response GetBackupsResponse
|
||||
err := json.Unmarshal(testResp.Body, &response)
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(response.Backups), 1)
|
||||
assert.GreaterOrEqual(t, response.Total, int64(1))
|
||||
} else {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace owner can create backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can create backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace viewer can create backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot create backup",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can create backup",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
enableBackupForDatabase(database.ID)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
request := MakeBackupRequest{DatabaseID: database.ID}
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backups",
|
||||
"Bearer "+testUserToken,
|
||||
request,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
|
||||
if tt.expectSuccess {
|
||||
assert.Contains(t, string(testResp.Body), "backup started successfully")
|
||||
} else {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateBackup_AuditLogWritten(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
enableBackupForDatabase(database.ID)
|
||||
|
||||
request := MakeBackupRequest{DatabaseID: database.ID}
|
||||
test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backups",
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
auditLogService := audit_logs.GetAuditLogService()
|
||||
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
|
||||
workspace.ID,
|
||||
&audit_logs.GetAuditLogsRequest{
|
||||
Limit: 100,
|
||||
Offset: 0,
|
||||
},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
found := false
|
||||
for _, log := range auditLogs.AuditLogs {
|
||||
if strings.Contains(log.Message, "Backup manually initiated") &&
|
||||
strings.Contains(log.Message, database.Name) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Audit log for backup creation not found")
|
||||
}
|
||||
|
||||
func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace owner can delete backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusNoContent,
|
||||
},
|
||||
{
|
||||
name: "workspace member can delete backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusNoContent,
|
||||
},
|
||||
{
|
||||
name: "workspace viewer cannot delete backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot delete backup",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can delete backup",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusNoContent,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
testResp := test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s", backup.ID.String()),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
|
||||
if !tt.expectSuccess {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
} else {
|
||||
userService := users_services.GetUserService()
|
||||
ownerUser, err := userService.GetUserFromToken(owner.Token)
|
||||
assert.NoError(t, err)
|
||||
|
||||
response, err := GetBackupService().GetBackups(ownerUser, database.ID, 10, 0)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(response.Backups))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_DeleteBackup_AuditLogWritten(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusNoContent,
|
||||
)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
auditLogService := audit_logs.GetAuditLogService()
|
||||
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
|
||||
workspace.ID,
|
||||
&audit_logs.GetAuditLogsRequest{
|
||||
Limit: 100,
|
||||
Offset: 0,
|
||||
},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
found := false
|
||||
for _, log := range auditLogs.AuditLogs {
|
||||
if strings.Contains(log.Message, "Backup deleted") &&
|
||||
strings.Contains(log.Message, database.Name) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Audit log for backup deletion not found")
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace viewer can download backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can download backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot download backup",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can download backup",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
|
||||
if !tt.expectSuccess {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
auditLogService := audit_logs.GetAuditLogService()
|
||||
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
|
||||
workspace.ID,
|
||||
&audit_logs.GetAuditLogsRequest{
|
||||
Limit: 100,
|
||||
Offset: 0,
|
||||
},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
found := false
|
||||
for _, log := range auditLogs.AuditLogs {
|
||||
if strings.Contains(log.Message, "Backup file downloaded") &&
|
||||
strings.Contains(log.Message, database.Name) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Audit log for backup download not found")
|
||||
}
|
||||
|
||||
func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.IsBackupsEnabled = true
|
||||
config.StorageID = &storage.ID
|
||||
config.Storage = storage
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup := &Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: BackupStatusInProgress,
|
||||
BackupSizeMb: 0,
|
||||
BackupDurationMs: 0,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &BackupRepository{}
|
||||
err = repo.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Register a cancellable context for the backup
|
||||
GetBackupService().backupContextManager.RegisterBackup(backup.ID, func() {})
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/cancel", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusNoContent,
|
||||
)
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
|
||||
|
||||
// Verify audit log was created
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
userService := users_services.GetUserService()
|
||||
adminUser, err := userService.GetUserFromToken(admin.Token)
|
||||
assert.NoError(t, err)
|
||||
|
||||
auditLogService := audit_logs.GetAuditLogService()
|
||||
auditLogs, err := auditLogService.GetGlobalAuditLogs(
|
||||
adminUser,
|
||||
&audit_logs.GetAuditLogsRequest{Limit: 100, Offset: 0},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
foundCancelLog := false
|
||||
for _, log := range auditLogs.AuditLogs {
|
||||
if strings.Contains(log.Message, "Backup cancelled") &&
|
||||
strings.Contains(log.Message, database.Name) {
|
||||
foundCancelLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundCancelLog, "Cancel audit log should be created")
|
||||
}
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
return CreateTestRouter()
|
||||
}
|
||||
|
||||
func createTestDatabase(
|
||||
name string,
|
||||
workspaceID uuid.UUID,
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
request := databases.Database{
|
||||
Name: name,
|
||||
WorkspaceID: &workspaceID,
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+token,
|
||||
request,
|
||||
)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
panic(
|
||||
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
|
||||
)
|
||||
}
|
||||
|
||||
var database databases.Database
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &database
|
||||
}
|
||||
|
||||
func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
|
||||
storage := &storages.Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: storages.StorageTypeLocal,
|
||||
Name: "Test Storage " + uuid.New().String(),
|
||||
LocalStorage: &local_storage.LocalStorage{},
|
||||
}
|
||||
|
||||
repo := &storages.StorageRepository{}
|
||||
storage, err := repo.Save(storage)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return storage
|
||||
}
|
||||
|
||||
func enableBackupForDatabase(databaseID uuid.UUID) {
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
config.IsBackupsEnabled = true
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func createTestDatabaseWithBackups(
|
||||
workspace *workspaces_models.Workspace,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
router *gin.Engine,
|
||||
) (*databases.Database, *Backup) {
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
config.IsBackupsEnabled = true
|
||||
config.StorageID = &storage.ID
|
||||
config.Storage = storage
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
backup := createTestBackup(database, owner)
|
||||
|
||||
return database, backup
|
||||
}
|
||||
|
||||
func createTestBackup(
|
||||
database *databases.Database,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
) *Backup {
|
||||
userService := users_services.GetUserService()
|
||||
user, err := userService.GetUserFromToken(owner.Token)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
|
||||
if err != nil || len(storages) == 0 {
|
||||
panic("No storage found for workspace")
|
||||
}
|
||||
|
||||
backup := &Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storages[0].ID,
|
||||
Status: BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &BackupRepository{}
|
||||
if err := repo.Save(backup); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Create a dummy backup file for testing download functionality
|
||||
dummyContent := []byte("dummy backup content for testing")
|
||||
reader := strings.NewReader(string(dummyContent))
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
if err := storages[0].SaveFile(context.Background(), encryption.GetFieldEncryptor(), logger, backup.ID, reader); err != nil {
|
||||
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
|
||||
}
|
||||
|
||||
return backup
|
||||
}
|
||||
@@ -0,0 +1,361 @@
|
||||
package backups_controllers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
backups_dto "databasus-backend/internal/features/backups/backups/dto"
|
||||
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||
"databasus-backend/internal/features/databases"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
)
|
||||
|
||||
type BackupController struct {
|
||||
backupService *backups_services.BackupService
|
||||
}
|
||||
|
||||
func (c *BackupController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
router.GET("/backups", c.GetBackups)
|
||||
router.POST("/backups", c.MakeBackup)
|
||||
router.POST("/backups/:id/download-token", c.GenerateDownloadToken)
|
||||
router.DELETE("/backups/:id", c.DeleteBackup)
|
||||
router.POST("/backups/:id/cancel", c.CancelBackup)
|
||||
}
|
||||
|
||||
// RegisterPublicRoutes registers routes that don't require Bearer authentication
|
||||
// (they have their own authentication mechanisms like download tokens)
|
||||
func (c *BackupController) RegisterPublicRoutes(router *gin.RouterGroup) {
|
||||
router.GET("/backups/:id/file", c.GetFile)
|
||||
}
|
||||
|
||||
// GetBackups
|
||||
// @Summary Get backups for a database
|
||||
// @Description Get paginated backups for the specified database
|
||||
// @Tags backups
|
||||
// @Produce json
|
||||
// @Param database_id query string true "Database ID"
|
||||
// @Param limit query int false "Number of items per page" default(10)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Success 200 {object} backups_dto.GetBackupsResponse
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /backups [get]
|
||||
func (c *BackupController) GetBackups(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request backups_dto.GetBackupsRequest
|
||||
if err := ctx.ShouldBindQuery(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
databaseID, err := uuid.Parse(request.DatabaseID)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database_id"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.backupService.GetBackups(user, databaseID, request.Limit, request.Offset)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// MakeBackup
|
||||
// @Summary Create a backup
|
||||
// @Description Create a new backup for the specified database
|
||||
// @Tags backups
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body backups_dto.MakeBackupRequest true "Backup creation data"
|
||||
// @Success 200 {object} map[string]string
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /backups [post]
|
||||
func (c *BackupController) MakeBackup(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request backups_dto.MakeBackupRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.backupService.MakeBackupWithAuth(user, request.DatabaseID); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, gin.H{"message": "backup started successfully"})
|
||||
}
|
||||
|
||||
// DeleteBackup
|
||||
// @Summary Delete a backup
|
||||
// @Description Delete an existing backup
|
||||
// @Tags backups
|
||||
// @Param id path string true "Backup ID"
|
||||
// @Success 204
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /backups/{id} [delete]
|
||||
func (c *BackupController) DeleteBackup(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.backupService.DeleteBackup(user, id); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// CancelBackup
|
||||
// @Summary Cancel an in-progress backup
|
||||
// @Description Cancel a backup that is currently in progress
|
||||
// @Tags backups
|
||||
// @Param id path string true "Backup ID"
|
||||
// @Success 204
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /backups/{id}/cancel [post]
|
||||
func (c *BackupController) CancelBackup(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.backupService.CancelBackup(user, id); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// GenerateDownloadToken
|
||||
// @Summary Generate short-lived download token
|
||||
// @Description Generate a token for downloading a backup file (valid for 5 minutes)
|
||||
// @Tags backups
|
||||
// @Param id path string true "Backup ID"
|
||||
// @Success 200 {object} backups_download.GenerateDownloadTokenResponse
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 409 {object} map[string]string "Download already in progress"
|
||||
// @Router /backups/{id}/download-token [post]
|
||||
func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.backupService.GenerateDownloadToken(user, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, backups_download.ErrDownloadAlreadyInProgress) {
|
||||
ctx.JSON(
|
||||
http.StatusConflict,
|
||||
gin.H{
|
||||
"error": "Download already in progress for some of backups. Please wait until previous download completed or cancel it",
|
||||
},
|
||||
)
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetFile
|
||||
// @Summary Download a backup file
|
||||
// @Description Download the backup file for the specified backup using a download token.
|
||||
// @Description
|
||||
// @Description **Download Concurrency Control:**
|
||||
// @Description - Only one download per user is allowed at a time
|
||||
// @Description - If a download is already in progress, returns 409 Conflict
|
||||
// @Description - Downloads are tracked using cache with 5-second TTL and 3-second heartbeat
|
||||
// @Description - Browser cancellations automatically release the download lock
|
||||
// @Description - Server crashes are handled via automatic cache expiry (5 seconds)
|
||||
// @Tags backups
|
||||
// @Param id path string true "Backup ID"
|
||||
// @Param token query string true "Download token"
|
||||
// @Success 200 {file} file
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 409 {object} map[string]string "Download already in progress"
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /backups/{id}/file [get]
|
||||
func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
token := ctx.Query("token")
|
||||
if token == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "download token is required"})
|
||||
return
|
||||
}
|
||||
|
||||
backupIDParam := ctx.Param("id")
|
||||
backupID, err := uuid.Parse(backupIDParam)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
|
||||
return
|
||||
}
|
||||
|
||||
downloadToken, rateLimiter, err := c.backupService.ValidateDownloadToken(token)
|
||||
if err != nil {
|
||||
if errors.Is(err, backups_download.ErrDownloadAlreadyInProgress) {
|
||||
ctx.JSON(
|
||||
http.StatusConflict,
|
||||
gin.H{
|
||||
"error": "download already in progress for this user. Please wait until previous download completed or cancel it",
|
||||
},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
|
||||
return
|
||||
}
|
||||
|
||||
if downloadToken.BackupID != backupID {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
|
||||
return
|
||||
}
|
||||
|
||||
fileReader, backup, database, err := c.backupService.GetBackupFileWithoutAuth(
|
||||
downloadToken.BackupID,
|
||||
)
|
||||
if err != nil {
|
||||
c.backupService.UnregisterDownload(downloadToken.UserID)
|
||||
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
rateLimitedReader := backups_download.NewRateLimitedReader(fileReader, rateLimiter)
|
||||
|
||||
heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
|
||||
defer func() {
|
||||
cancelHeartbeat()
|
||||
c.backupService.UnregisterDownload(downloadToken.UserID)
|
||||
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
|
||||
if err := rateLimitedReader.Close(); err != nil {
|
||||
fmt.Printf("Error closing file reader: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go c.startDownloadHeartbeat(heartbeatCtx, downloadToken.UserID)
|
||||
|
||||
filename := c.generateBackupFilename(backup, database)
|
||||
|
||||
if backup.BackupSizeMb > 0 {
|
||||
sizeBytes := int64(backup.BackupSizeMb * 1024 * 1024)
|
||||
ctx.Header("Content-Length", fmt.Sprintf("%d", sizeBytes))
|
||||
}
|
||||
|
||||
ctx.Header("Content-Type", "application/octet-stream")
|
||||
ctx.Header(
|
||||
"Content-Disposition",
|
||||
fmt.Sprintf("attachment; filename=\"%s\"", filename),
|
||||
)
|
||||
|
||||
_, err = io.Copy(ctx.Writer, rateLimitedReader)
|
||||
if err != nil {
|
||||
fmt.Printf("Error streaming file: %v\n", err)
|
||||
}
|
||||
|
||||
c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
|
||||
}
|
||||
|
||||
func (c *BackupController) generateBackupFilename(
|
||||
backup *backups_core.Backup,
|
||||
database *databases.Database,
|
||||
) string {
|
||||
// Format timestamp as YYYY-MM-DD_HH-mm-ss
|
||||
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")
|
||||
|
||||
// Sanitize database name for filename (replace spaces and special chars)
|
||||
safeName := files_utils.SanitizeFilename(database.Name)
|
||||
|
||||
// Determine extension based on database type
|
||||
extension := c.getBackupExtension(database.Type)
|
||||
|
||||
return fmt.Sprintf("%s_backup_%s%s", safeName, timestamp, extension)
|
||||
}
|
||||
|
||||
func (c *BackupController) getBackupExtension(
|
||||
dbType databases.DatabaseType,
|
||||
) string {
|
||||
switch dbType {
|
||||
case databases.DatabaseTypeMysql, databases.DatabaseTypeMariadb:
|
||||
return ".sql.zst"
|
||||
case databases.DatabaseTypePostgres:
|
||||
// PostgreSQL custom format
|
||||
return ".dump"
|
||||
case databases.DatabaseTypeMongodb:
|
||||
return ".archive"
|
||||
default:
|
||||
return ".backup"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uuid.UUID) {
|
||||
ticker := time.NewTicker(backups_download.GetDownloadHeartbeatInterval())
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.backupService.RefreshDownloadLock(userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
23
backend/internal/features/backups/backups/controllers/di.go
Normal file
23
backend/internal/features/backups/backups/controllers/di.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package backups_controllers
|
||||
|
||||
import (
|
||||
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||
"databasus-backend/internal/features/databases"
|
||||
)
|
||||
|
||||
var backupController = &BackupController{
|
||||
backups_services.GetBackupService(),
|
||||
}
|
||||
|
||||
func GetBackupController() *BackupController {
|
||||
return backupController
|
||||
}
|
||||
|
||||
var postgresWalBackupController = &PostgreWalBackupController{
|
||||
databases.GetDatabaseService(),
|
||||
backups_services.GetWalService(),
|
||||
}
|
||||
|
||||
func GetPostgresWalBackupController() *PostgreWalBackupController {
|
||||
return postgresWalBackupController
|
||||
}
|
||||
@@ -0,0 +1,291 @@
|
||||
package backups_controllers
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_dto "databasus-backend/internal/features/backups/backups/dto"
|
||||
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||
"databasus-backend/internal/features/databases"
|
||||
)
|
||||
|
||||
// PostgreWalBackupController handles WAL backup endpoints used by the databasus-cli agent.
|
||||
// Authentication is via a plain agent token in the Authorization header (no Bearer prefix).
|
||||
type PostgreWalBackupController struct {
|
||||
databaseService *databases.DatabaseService
|
||||
walService *backups_services.PostgreWalBackupService
|
||||
}
|
||||
|
||||
func (c *PostgreWalBackupController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
walRoutes := router.Group("/backups/postgres/wal")
|
||||
|
||||
walRoutes.GET("/next-full-backup-time", c.GetNextFullBackupTime)
|
||||
walRoutes.POST("/error", c.ReportError)
|
||||
walRoutes.POST("/upload", c.Upload)
|
||||
walRoutes.GET("/restore/plan", c.GetRestorePlan)
|
||||
walRoutes.GET("/restore/download", c.DownloadBackupFile)
|
||||
}
|
||||
|
||||
// GetNextFullBackupTime
|
||||
// @Summary Get next full backup time
|
||||
// @Description Returns the next scheduled full basebackup time for the authenticated database
|
||||
// @Tags backups-wal
|
||||
// @Produce json
|
||||
// @Security AgentToken
|
||||
// @Success 200 {object} backups_dto.GetNextFullBackupTimeResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /backups/postgres/wal/next-full-backup-time [get]
|
||||
func (c *PostgreWalBackupController) GetNextFullBackupTime(ctx *gin.Context) {
|
||||
database, err := c.getDatabase(ctx)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.walService.GetNextFullBackupTime(database)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// ReportError
|
||||
// @Summary Report agent error
|
||||
// @Description Records a fatal error from the agent against the database record and marks it as errored
|
||||
// @Tags backups-wal
|
||||
// @Accept json
|
||||
// @Security AgentToken
|
||||
// @Param request body backups_dto.ReportErrorRequest true "Error details"
|
||||
// @Success 200
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /backups/postgres/wal/error [post]
|
||||
func (c *PostgreWalBackupController) ReportError(ctx *gin.Context) {
|
||||
database, err := c.getDatabase(ctx)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
|
||||
return
|
||||
}
|
||||
|
||||
var request backups_dto.ReportErrorRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.walService.ReportError(database, request.Error); err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
// Upload
|
||||
// @Summary Stream upload a basebackup or WAL segment
|
||||
// @Description Accepts a zstd-compressed binary stream and stores it in the database's configured storage.
|
||||
// The server generates the storage filename; agents do not control the destination path.
|
||||
// For WAL segment uploads the server validates the WAL chain and returns 409 if a gap is detected
|
||||
// or 400 if no full backup exists yet (agent should trigger a full basebackup in both cases).
|
||||
// @Tags backups-wal
|
||||
// @Accept application/octet-stream
|
||||
// @Produce json
|
||||
// @Security AgentToken
|
||||
// @Param X-Upload-Type header string true "Upload type" Enums(basebackup, wal)
|
||||
// @Param X-Wal-Segment-Name header string false "24-hex WAL segment identifier (required for wal uploads, e.g. 0000000100000001000000AB)"
|
||||
// @Param X-Wal-Segment-Size header int false "WAL segment size in bytes reported by the PostgreSQL instance (default: 16777216)"
|
||||
// @Param fullBackupWalStartSegment query string false "First WAL segment needed to make the basebackup consistent (required for basebackup uploads)"
|
||||
// @Param fullBackupWalStopSegment query string false "Last WAL segment included in the basebackup (required for basebackup uploads)"
|
||||
// @Success 204
|
||||
// @Failure 400 {object} backups_dto.UploadGapResponse "No full backup exists (error: no_full_backup)"
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 409 {object} backups_dto.UploadGapResponse "WAL chain gap detected (error: gap_detected)"
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /backups/postgres/wal/upload [post]
|
||||
func (c *PostgreWalBackupController) Upload(ctx *gin.Context) {
|
||||
database, err := c.getDatabase(ctx)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
|
||||
return
|
||||
}
|
||||
|
||||
uploadType := backups_core.PgWalUploadType(ctx.GetHeader("X-Upload-Type"))
|
||||
if uploadType != backups_core.PgWalUploadTypeBasebackup &&
|
||||
uploadType != backups_core.PgWalUploadTypeWal {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "X-Upload-Type must be 'basebackup' or 'wal'"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
walSegmentName := ""
|
||||
if uploadType == backups_core.PgWalUploadTypeWal {
|
||||
walSegmentName = ctx.GetHeader("X-Wal-Segment-Name")
|
||||
if walSegmentName == "" {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "X-Wal-Segment-Name is required for wal uploads"},
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if uploadType == backups_core.PgWalUploadTypeBasebackup {
|
||||
if ctx.Query("fullBackupWalStartSegment") == "" ||
|
||||
ctx.Query("fullBackupWalStopSegment") == "" {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{
|
||||
"error": "fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads",
|
||||
},
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
walSegmentSizeBytes := int64(0)
|
||||
if raw := ctx.GetHeader("X-Wal-Segment-Size"); raw != "" {
|
||||
parsed, parseErr := strconv.ParseInt(raw, 10, 64)
|
||||
if parseErr != nil || parsed <= 0 {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "X-Wal-Segment-Size must be a positive integer"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
walSegmentSizeBytes = parsed
|
||||
}
|
||||
|
||||
gapResp, uploadErr := c.walService.UploadWal(
|
||||
ctx.Request.Context(),
|
||||
database,
|
||||
uploadType,
|
||||
walSegmentName,
|
||||
ctx.Query("fullBackupWalStartSegment"),
|
||||
ctx.Query("fullBackupWalStopSegment"),
|
||||
walSegmentSizeBytes,
|
||||
ctx.Request.Body,
|
||||
)
|
||||
|
||||
if uploadErr != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": uploadErr.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if gapResp != nil {
|
||||
if gapResp.Error == "no_full_backup" {
|
||||
ctx.JSON(http.StatusBadRequest, gapResp)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusConflict, gapResp)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// GetRestorePlan
|
||||
// @Summary Get restore plan
|
||||
// @Description Resolves the full backup and all required WAL segments needed for recovery. Validates the WAL chain is continuous.
|
||||
// @Tags backups-wal
|
||||
// @Produce json
|
||||
// @Security AgentToken
|
||||
// @Param backupId query string false "UUID of a specific full backup to restore from; defaults to the most recent"
|
||||
// @Success 200 {object} backups_dto.GetRestorePlanResponse
|
||||
// @Failure 400 {object} map[string]string "Broken WAL chain or no backups available"
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /backups/postgres/wal/restore/plan [get]
|
||||
func (c *PostgreWalBackupController) GetRestorePlan(ctx *gin.Context) {
|
||||
database, err := c.getDatabase(ctx)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
|
||||
return
|
||||
}
|
||||
|
||||
var backupID *uuid.UUID
|
||||
if raw := ctx.Query("backupId"); raw != "" {
|
||||
parsed, parseErr := uuid.Parse(raw)
|
||||
if parseErr != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backupId format"})
|
||||
return
|
||||
}
|
||||
|
||||
backupID = &parsed
|
||||
}
|
||||
|
||||
response, planErr, err := c.walService.GetRestorePlan(database, backupID)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if planErr != nil {
|
||||
ctx.JSON(http.StatusBadRequest, planErr)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// DownloadBackupFile
|
||||
// @Summary Download a backup or WAL segment file for restore
|
||||
// @Description Retrieves the backup file by ID (validated against the authenticated database), decrypts it server-side if encrypted, and streams the zstd-compressed result to the agent
|
||||
// @Tags backups-wal
|
||||
// @Produce application/octet-stream
|
||||
// @Security AgentToken
|
||||
// @Param backupId query string true "Backup ID from the restore plan response"
|
||||
// @Success 200 {file} file
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Router /backups/postgres/wal/restore/download [get]
|
||||
func (c *PostgreWalBackupController) DownloadBackupFile(ctx *gin.Context) {
|
||||
database, err := c.getDatabase(ctx)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
|
||||
return
|
||||
}
|
||||
|
||||
backupIDRaw := ctx.Query("backupId")
|
||||
if backupIDRaw == "" {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "backupId is required"})
|
||||
return
|
||||
}
|
||||
|
||||
backupID, err := uuid.Parse(backupIDRaw)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backupId format"})
|
||||
return
|
||||
}
|
||||
|
||||
reader, err := c.walService.DownloadBackupFile(database, backupID)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer func() { _ = reader.Close() }()
|
||||
|
||||
ctx.Header("Content-Type", "application/octet-stream")
|
||||
ctx.Status(http.StatusOK)
|
||||
|
||||
_, _ = io.Copy(ctx.Writer, reader)
|
||||
}
|
||||
|
||||
func (c *PostgreWalBackupController) getDatabase(
|
||||
ctx *gin.Context,
|
||||
) (*databases.Database, error) {
|
||||
token := ctx.GetHeader("Authorization")
|
||||
return c.databaseService.GetDatabaseByAgentToken(token)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,97 @@
|
||||
package backups_controllers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
)
|
||||
|
||||
func CreateTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
backups_config.GetBackupConfigController(),
|
||||
GetBackupController(),
|
||||
)
|
||||
|
||||
// Register public routes (no auth required - token-based)
|
||||
v1 := router.Group("/api/v1")
|
||||
GetBackupController().RegisterPublicRoutes(v1)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
// WaitForBackupCompletion waits for a new backup to be created and completed (or failed)
|
||||
// for the given database. It checks for backups with count greater than expectedInitialCount.
|
||||
func WaitForBackupCompletion(
|
||||
t *testing.T,
|
||||
databaseID uuid.UUID,
|
||||
expectedInitialCount int,
|
||||
timeout time.Duration,
|
||||
) {
|
||||
deadline := time.Now().UTC().Add(timeout)
|
||||
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
backups, err := backups_core.GetBackupRepository().FindByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
t.Logf("WaitForBackupCompletion: error finding backups: %v", err)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
t.Logf(
|
||||
"WaitForBackupCompletion: found %d backups (expected > %d)",
|
||||
len(backups),
|
||||
expectedInitialCount,
|
||||
)
|
||||
|
||||
if len(backups) > expectedInitialCount {
|
||||
// Check if the newest backup has completed or failed
|
||||
newestBackup := backups[0]
|
||||
t.Logf("WaitForBackupCompletion: newest backup status: %s", newestBackup.Status)
|
||||
|
||||
if newestBackup.Status == backups_core.BackupStatusCompleted ||
|
||||
newestBackup.Status == backups_core.BackupStatusFailed ||
|
||||
newestBackup.Status == backups_core.BackupStatusCanceled {
|
||||
t.Logf(
|
||||
"WaitForBackupCompletion: backup finished with status %s",
|
||||
newestBackup.Status,
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
|
||||
}
|
||||
|
||||
// CreateTestBackup creates a simple test backup record for testing purposes
|
||||
func CreateTestBackup(databaseID, storageID uuid.UUID) *backups_core.Backup {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: databaseID,
|
||||
StorageID: storageID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &backups_core.BackupRepository{}
|
||||
if err := repo.Save(backup); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return backup
|
||||
}
|
||||
7
backend/internal/features/backups/backups/core/di.go
Normal file
7
backend/internal/features/backups/backups/core/di.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package backups_core
|
||||
|
||||
var backupRepository = &BackupRepository{}
|
||||
|
||||
func GetBackupRepository() *BackupRepository {
|
||||
return backupRepository
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
type BackupStatus string
|
||||
|
||||
@@ -8,3 +8,10 @@ const (
|
||||
BackupStatusFailed BackupStatus = "FAILED"
|
||||
BackupStatusCanceled BackupStatus = "CANCELED"
|
||||
)
|
||||
|
||||
type PgWalUploadType string
|
||||
|
||||
const (
|
||||
PgWalUploadTypeBasebackup PgWalUploadType = "basebackup"
|
||||
PgWalUploadTypeWal PgWalUploadType = "wal"
|
||||
)
|
||||
@@ -1,15 +1,13 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
usecases_common "databasus-backend/internal/features/backups/backups/usecases/common"
|
||||
usecases_common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type NotificationSender interface {
|
||||
@@ -23,7 +21,7 @@ type NotificationSender interface {
|
||||
type CreateBackupUsecase interface {
|
||||
Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
58
backend/internal/features/backups/backups/core/model.go
Normal file
58
backend/internal/features/backups/backups/core/model.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package backups_core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
)
|
||||
|
||||
type PgWalBackupType string
|
||||
|
||||
const (
|
||||
PgWalBackupTypeFullBackup PgWalBackupType = "PG_FULL_BACKUP"
|
||||
PgWalBackupTypeWalSegment PgWalBackupType = "PG_WAL_SEGMENT"
|
||||
)
|
||||
|
||||
type Backup struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
|
||||
FileName string `json:"fileName" gorm:"column:file_name;type:text;not null"`
|
||||
|
||||
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
|
||||
StorageID uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;not null"`
|
||||
|
||||
Status BackupStatus `json:"status" gorm:"column:status;not null"`
|
||||
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
|
||||
IsSkipRetry bool `json:"isSkipRetry" gorm:"column:is_skip_retry;type:boolean;not null"`
|
||||
|
||||
BackupSizeMb float64 `json:"backupSizeMb" gorm:"column:backup_size_mb;default:0"`
|
||||
|
||||
BackupDurationMs int64 `json:"backupDurationMs" gorm:"column:backup_duration_ms;default:0"`
|
||||
|
||||
EncryptionSalt *string `json:"-" gorm:"column:encryption_salt"`
|
||||
EncryptionIV *string `json:"-" gorm:"column:encryption_iv"`
|
||||
Encryption backups_config.BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
|
||||
|
||||
// Postgres WAL backup specific fields
|
||||
PgWalBackupType *PgWalBackupType `json:"pgWalBackupType" gorm:"column:pg_wal_backup_type;type:text"`
|
||||
PgFullBackupWalStartSegmentName *string `json:"pgFullBackupWalStartSegmentName" gorm:"column:pg_wal_start_segment;type:text"`
|
||||
PgFullBackupWalStopSegmentName *string `json:"pgFullBackupWalStopSegmentName" gorm:"column:pg_wal_stop_segment;type:text"`
|
||||
PgVersion *string `json:"pgVersion" gorm:"column:pg_version;type:text"`
|
||||
PgWalSegmentName *string `json:"pgWalSegmentName" gorm:"column:pg_wal_segment_name;type:text"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
||||
}
|
||||
|
||||
func (b *Backup) GenerateFilename(dbName string) {
|
||||
timestamp := time.Now().UTC()
|
||||
|
||||
b.FileName = fmt.Sprintf(
|
||||
"%s-%s-%s",
|
||||
files_utils.SanitizeFilename(dbName),
|
||||
timestamp.Format("20060102-150405"),
|
||||
b.ID.String(),
|
||||
)
|
||||
}
|
||||
@@ -1,13 +1,13 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
"errors"
|
||||
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"databasus-backend/internal/storage"
|
||||
)
|
||||
|
||||
type BackupRepository struct{}
|
||||
@@ -88,7 +88,7 @@ func (r *BackupRepository) FindLastByDatabaseID(databaseID uuid.UUID) (*Backup,
|
||||
Where("database_id = ?", databaseID).
|
||||
Order("created_at DESC").
|
||||
First(&backup).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -212,3 +212,167 @@ func (r *BackupRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) GetTotalSizeByDatabase(databaseID uuid.UUID) (float64, error) {
|
||||
var totalSize float64
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Model(&Backup{}).
|
||||
Select("COALESCE(SUM(backup_size_mb), 0)").
|
||||
Where("database_id = ? AND status != ?", databaseID, BackupStatusInProgress).
|
||||
Scan(&totalSize).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return totalSize, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) FindOldestByDatabaseExcludingInProgress(
|
||||
databaseID uuid.UUID,
|
||||
limit int,
|
||||
) ([]*Backup, error) {
|
||||
var backups []*Backup
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Where("database_id = ? AND status != ?", databaseID, BackupStatusInProgress).
|
||||
Order("created_at ASC").
|
||||
Limit(limit).
|
||||
Find(&backups).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return backups, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) FindCompletedFullWalBackupByID(
|
||||
databaseID uuid.UUID,
|
||||
backupID uuid.UUID,
|
||||
) (*Backup, error) {
|
||||
var backup Backup
|
||||
|
||||
err := storage.
|
||||
GetDb().
|
||||
Where(
|
||||
"database_id = ? AND id = ? AND pg_wal_backup_type = ? AND status = ?",
|
||||
databaseID,
|
||||
backupID,
|
||||
PgWalBackupTypeFullBackup,
|
||||
BackupStatusCompleted,
|
||||
).
|
||||
First(&backup).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &backup, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) FindCompletedWalSegmentsAfter(
|
||||
databaseID uuid.UUID,
|
||||
afterSegmentName string,
|
||||
) ([]*Backup, error) {
|
||||
var backups []*Backup
|
||||
|
||||
err := storage.
|
||||
GetDb().
|
||||
Where(
|
||||
"database_id = ? AND pg_wal_backup_type = ? AND pg_wal_segment_name >= ? AND status = ?",
|
||||
databaseID,
|
||||
PgWalBackupTypeWalSegment,
|
||||
afterSegmentName,
|
||||
BackupStatusCompleted,
|
||||
).
|
||||
Order("pg_wal_segment_name ASC").
|
||||
Find(&backups).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return backups, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) FindLastCompletedFullWalBackupByDatabaseID(
|
||||
databaseID uuid.UUID,
|
||||
) (*Backup, error) {
|
||||
var backup Backup
|
||||
|
||||
err := storage.
|
||||
GetDb().
|
||||
Where(
|
||||
"database_id = ? AND pg_wal_backup_type = ? AND status = ?",
|
||||
databaseID,
|
||||
PgWalBackupTypeFullBackup,
|
||||
BackupStatusCompleted,
|
||||
).
|
||||
Order("created_at DESC").
|
||||
First(&backup).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &backup, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) FindWalSegmentByName(
|
||||
databaseID uuid.UUID,
|
||||
segmentName string,
|
||||
) (*Backup, error) {
|
||||
var backup Backup
|
||||
|
||||
err := storage.
|
||||
GetDb().
|
||||
Where(
|
||||
"database_id = ? AND pg_wal_backup_type = ? AND pg_wal_segment_name = ?",
|
||||
databaseID,
|
||||
PgWalBackupTypeWalSegment,
|
||||
segmentName,
|
||||
).
|
||||
First(&backup).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &backup, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) FindLastWalSegmentAfter(
|
||||
databaseID uuid.UUID,
|
||||
afterSegmentName string,
|
||||
) (*Backup, error) {
|
||||
var backup Backup
|
||||
|
||||
err := storage.
|
||||
GetDb().
|
||||
Where(
|
||||
"database_id = ? AND pg_wal_backup_type = ? AND pg_wal_segment_name > ? AND status = ?",
|
||||
databaseID,
|
||||
PgWalBackupTypeWalSegment,
|
||||
afterSegmentName,
|
||||
BackupStatusCompleted,
|
||||
).
|
||||
Order("pg_wal_segment_name DESC").
|
||||
First(&backup).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &backup, nil
|
||||
}
|
||||
@@ -1,71 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var backupRepository = &BackupRepository{}
|
||||
|
||||
var backupContextManager = NewBackupContextManager()
|
||||
|
||||
var backupService = &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
notifiers.GetNotifierService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
backupContextManager,
|
||||
}
|
||||
|
||||
var backupBackgroundService = &BackupBackgroundService{
|
||||
backupService,
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
}
|
||||
|
||||
var backupController = &BackupController{
|
||||
backupService,
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
backups_config.
|
||||
GetBackupConfigService().
|
||||
SetDatabaseStorageChangeListener(backupService)
|
||||
|
||||
databases.GetDatabaseService().AddDbRemoveListener(backupService)
|
||||
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
|
||||
}
|
||||
|
||||
func GetBackupService() *BackupService {
|
||||
return backupService
|
||||
}
|
||||
|
||||
func GetBackupController() *BackupController {
|
||||
return backupController
|
||||
}
|
||||
|
||||
func GetBackupBackgroundService() *BackupBackgroundService {
|
||||
return backupBackgroundService
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DownloadTokenBackgroundService struct {
|
||||
downloadTokenService *DownloadTokenService
|
||||
logger *slog.Logger
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
s.logger.Info("Starting download token cleanup background service")
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
|
||||
s.logger.Error("Failed to clean expired download tokens", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BandwidthManager struct {
|
||||
mu sync.RWMutex
|
||||
activeDownloads map[uuid.UUID]*activeDownload
|
||||
maxTotalBytesPerSecond int64
|
||||
bytesPerSecondPerDownload int64
|
||||
}
|
||||
|
||||
type activeDownload struct {
|
||||
userID uuid.UUID
|
||||
rateLimiter *RateLimiter
|
||||
}
|
||||
|
||||
func NewBandwidthManager(throughputMBs int) *BandwidthManager {
|
||||
// Use 75% of total throughput
|
||||
maxBytes := int64(throughputMBs) * 1024 * 1024 * 75 / 100
|
||||
|
||||
return &BandwidthManager{
|
||||
activeDownloads: make(map[uuid.UUID]*activeDownload),
|
||||
maxTotalBytesPerSecond: maxBytes,
|
||||
bytesPerSecondPerDownload: maxBytes,
|
||||
}
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) RegisterDownload(userID uuid.UUID) (*RateLimiter, error) {
|
||||
bm.mu.Lock()
|
||||
defer bm.mu.Unlock()
|
||||
|
||||
if _, exists := bm.activeDownloads[userID]; exists {
|
||||
return nil, fmt.Errorf("download already registered for user %s", userID)
|
||||
}
|
||||
|
||||
rateLimiter := NewRateLimiter(bm.bytesPerSecondPerDownload)
|
||||
|
||||
bm.activeDownloads[userID] = &activeDownload{
|
||||
userID: userID,
|
||||
rateLimiter: rateLimiter,
|
||||
}
|
||||
|
||||
bm.recalculateRates()
|
||||
|
||||
return rateLimiter, nil
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) UnregisterDownload(userID uuid.UUID) {
|
||||
bm.mu.Lock()
|
||||
defer bm.mu.Unlock()
|
||||
|
||||
delete(bm.activeDownloads, userID)
|
||||
bm.recalculateRates()
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) GetActiveDownloadCount() int {
|
||||
bm.mu.RLock()
|
||||
defer bm.mu.RUnlock()
|
||||
return len(bm.activeDownloads)
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) recalculateRates() {
|
||||
activeCount := len(bm.activeDownloads)
|
||||
|
||||
if activeCount == 0 {
|
||||
bm.bytesPerSecondPerDownload = bm.maxTotalBytesPerSecond
|
||||
return
|
||||
}
|
||||
|
||||
newRate := bm.maxTotalBytesPerSecond / int64(activeCount)
|
||||
bm.bytesPerSecondPerDownload = newRate
|
||||
|
||||
for _, download := range bm.activeDownloads {
|
||||
download.rateLimiter.UpdateRate(newRate)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_BandwidthManager_RegisterSingleDownload(t *testing.T) {
|
||||
throughputMBs := 100
|
||||
manager := NewBandwidthManager(throughputMBs)
|
||||
|
||||
expectedBytesPerSec := int64(100 * 1024 * 1024 * 75 / 100)
|
||||
assert.Equal(t, expectedBytesPerSec, manager.maxTotalBytesPerSecond)
|
||||
assert.Equal(t, expectedBytesPerSec, manager.bytesPerSecondPerDownload)
|
||||
|
||||
userID := uuid.New()
|
||||
rateLimiter, err := manager.RegisterDownload(userID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, rateLimiter)
|
||||
|
||||
assert.Equal(t, 1, manager.GetActiveDownloadCount())
|
||||
assert.Equal(t, expectedBytesPerSec, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedBytesPerSec, rateLimiter.bytesPerSecond)
|
||||
}
|
||||
|
||||
func Test_BandwidthManager_RegisterMultipleDownloads_BandwidthShared(t *testing.T) {
|
||||
throughputMBs := 100
|
||||
manager := NewBandwidthManager(throughputMBs)
|
||||
|
||||
maxBytes := int64(100 * 1024 * 1024 * 75 / 100)
|
||||
|
||||
user1 := uuid.New()
|
||||
rateLimiter1, err := manager.RegisterDownload(user1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, maxBytes, rateLimiter1.bytesPerSecond)
|
||||
|
||||
user2 := uuid.New()
|
||||
rateLimiter2, err := manager.RegisterDownload(user2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedPerDownload := maxBytes / 2
|
||||
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
|
||||
|
||||
user3 := uuid.New()
|
||||
rateLimiter3, err := manager.RegisterDownload(user3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedPerDownload = maxBytes / 3
|
||||
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter3.bytesPerSecond)
|
||||
assert.Equal(t, 3, manager.GetActiveDownloadCount())
|
||||
}
|
||||
|
||||
func Test_BandwidthManager_UnregisterDownload_BandwidthRebalanced(t *testing.T) {
|
||||
throughputMBs := 100
|
||||
manager := NewBandwidthManager(throughputMBs)
|
||||
|
||||
maxBytes := int64(100 * 1024 * 1024 * 75 / 100)
|
||||
|
||||
user1 := uuid.New()
|
||||
rateLimiter1, _ := manager.RegisterDownload(user1)
|
||||
|
||||
user2 := uuid.New()
|
||||
_, _ = manager.RegisterDownload(user2)
|
||||
|
||||
user3 := uuid.New()
|
||||
rateLimiter3, _ := manager.RegisterDownload(user3)
|
||||
|
||||
assert.Equal(t, 3, manager.GetActiveDownloadCount())
|
||||
expectedPerDownload := maxBytes / 3
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
|
||||
manager.UnregisterDownload(user2)
|
||||
|
||||
assert.Equal(t, 2, manager.GetActiveDownloadCount())
|
||||
expectedPerDownload = maxBytes / 2
|
||||
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter3.bytesPerSecond)
|
||||
|
||||
manager.UnregisterDownload(user1)
|
||||
|
||||
assert.Equal(t, 1, manager.GetActiveDownloadCount())
|
||||
assert.Equal(t, maxBytes, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, maxBytes, rateLimiter3.bytesPerSecond)
|
||||
|
||||
manager.UnregisterDownload(user3)
|
||||
assert.Equal(t, 0, manager.GetActiveDownloadCount())
|
||||
assert.Equal(t, maxBytes, manager.bytesPerSecondPerDownload)
|
||||
}
|
||||
|
||||
func Test_BandwidthManager_RegisterDuplicateUser_ReturnsError(t *testing.T) {
|
||||
manager := NewBandwidthManager(100)
|
||||
|
||||
userID := uuid.New()
|
||||
_, err := manager.RegisterDownload(userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = manager.RegisterDownload(userID)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "download already registered")
|
||||
}
|
||||
|
||||
func Test_RateLimiter_TokenBucketBasic(t *testing.T) {
|
||||
bytesPerSec := int64(1024 * 1024)
|
||||
limiter := NewRateLimiter(bytesPerSec)
|
||||
|
||||
assert.Equal(t, bytesPerSec, limiter.bytesPerSecond)
|
||||
assert.Equal(t, bytesPerSec*2, limiter.bucketSize)
|
||||
|
||||
start := time.Now()
|
||||
limiter.Wait(512 * 1024)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.Less(t, elapsed, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RateLimiter_UpdateRate(t *testing.T) {
|
||||
limiter := NewRateLimiter(1024 * 1024)
|
||||
|
||||
assert.Equal(t, int64(1024*1024), limiter.bytesPerSecond)
|
||||
|
||||
newRate := int64(2 * 1024 * 1024)
|
||||
limiter.UpdateRate(newRate)
|
||||
|
||||
assert.Equal(t, newRate, limiter.bytesPerSecond)
|
||||
assert.Equal(t, newRate*2, limiter.bucketSize)
|
||||
}
|
||||
|
||||
func Test_RateLimiter_ThrottlesCorrectly(t *testing.T) {
|
||||
bytesPerSec := int64(1024 * 1024)
|
||||
limiter := NewRateLimiter(bytesPerSec)
|
||||
|
||||
limiter.availableTokens = 0
|
||||
|
||||
start := time.Now()
|
||||
limiter.Wait(bytesPerSec / 2)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.GreaterOrEqual(t, elapsed, 400*time.Millisecond)
|
||||
assert.LessOrEqual(t, elapsed, 700*time.Millisecond)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user