mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
Compare commits
79 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4344f5ea5e | ||
|
|
7c6afa5b88 | ||
|
|
dbac799e1b | ||
|
|
7ee3817089 | ||
|
|
bae6f7f007 | ||
|
|
55dc087ddd | ||
|
|
c94d0db637 | ||
|
|
a1adef2261 | ||
|
|
4602dc3f88 | ||
|
|
cbbfc5ea8f | ||
|
|
dd1072e230 | ||
|
|
a495e5317a | ||
|
|
7eed647038 | ||
|
|
6973241e25 | ||
|
|
ab181f5b81 | ||
|
|
b60a0cc170 | ||
|
|
f319a497b3 | ||
|
|
bc870b3f8e | ||
|
|
15383c59eb | ||
|
|
d14c223a65 | ||
|
|
2c0a294027 | ||
|
|
5d851d73bd | ||
|
|
699913c251 | ||
|
|
a2e3f30a6d | ||
|
|
80f1174ecd | ||
|
|
a47f8d5e2c | ||
|
|
54b9e67656 | ||
|
|
3782846872 | ||
|
|
245a81897f | ||
|
|
5cbc0773b6 | ||
|
|
997fc01442 | ||
|
|
6d0ae32d0c | ||
|
|
011985d723 | ||
|
|
d677ee61de | ||
|
|
c6b8f6e87a | ||
|
|
2bb5f93d00 | ||
|
|
b91c150300 | ||
|
|
12b119ce40 | ||
|
|
7c6f0ab4ba | ||
|
|
6d2db4b298 | ||
|
|
6397423298 | ||
|
|
3470aae8e3 | ||
|
|
184fbcdb2c | ||
|
|
2d897dd722 | ||
|
|
cba40afd00 | ||
|
|
7aea012aeb | ||
|
|
6d5534deaa | ||
|
|
c04bd54683 | ||
|
|
1c3f16b372 | ||
|
|
ed08da56a6 | ||
|
|
c53e84b48d | ||
|
|
dbfeb9e27f | ||
|
|
02e86ffb3b | ||
|
|
207382116c | ||
|
|
a91ee50e31 | ||
|
|
7e5562b115 | ||
|
|
3ef51c4d68 | ||
|
|
e47e513460 | ||
|
|
226a6c06e6 | ||
|
|
615fd9d574 | ||
|
|
e9fcf20cdf | ||
|
|
7649f4acfd | ||
|
|
7e4c3bcc19 | ||
|
|
f2aecc0427 | ||
|
|
3ce7da319f | ||
|
|
096098f660 | ||
|
|
c3ba4a7c5a | ||
|
|
52c0f53608 | ||
|
|
a5095acad4 | ||
|
|
a6d32b5c09 | ||
|
|
722560e824 | ||
|
|
496ac6120c | ||
|
|
756c6c87af | ||
|
|
a23d05b735 | ||
|
|
33a8d302eb | ||
|
|
25ed1ffd2a | ||
|
|
67582325bb | ||
|
|
5a89558cf6 | ||
|
|
0ec02430b7 |
149
.github/workflows/ci-release.yml
vendored
149
.github/workflows/ci-release.yml
vendored
@@ -17,17 +17,7 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.24.4"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
go-version: "1.24.9"
|
||||
|
||||
- name: Install golangci-lint
|
||||
run: |
|
||||
@@ -63,8 +53,6 @@ jobs:
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
cache: "npm"
|
||||
cache-dependency-path: frontend/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
@@ -93,8 +81,6 @@ jobs:
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
cache: "npm"
|
||||
cache-dependency-path: frontend/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
@@ -134,17 +120,7 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.24.4"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
go-version: "1.24.9"
|
||||
|
||||
- name: Create .env file for testing
|
||||
run: |
|
||||
@@ -221,6 +197,12 @@ jobs:
|
||||
TEST_MONGODB_60_PORT=27060
|
||||
TEST_MONGODB_70_PORT=27070
|
||||
TEST_MONGODB_82_PORT=27082
|
||||
# Valkey (cache)
|
||||
VALKEY_HOST=localhost
|
||||
VALKEY_PORT=6379
|
||||
VALKEY_USERNAME=
|
||||
VALKEY_PASSWORD=
|
||||
VALKEY_IS_SSL=false
|
||||
EOF
|
||||
|
||||
- name: Start test containers
|
||||
@@ -233,6 +215,11 @@ jobs:
|
||||
# Wait for main dev database
|
||||
timeout 60 bash -c 'until docker exec dev-db pg_isready -h localhost -p 5437 -U postgres; do sleep 2; done'
|
||||
|
||||
# Wait for Valkey (cache)
|
||||
echo "Waiting for Valkey..."
|
||||
timeout 60 bash -c 'until docker exec dev-valkey valkey-cli ping 2>/dev/null | grep -q PONG; do sleep 2; done'
|
||||
echo "Valkey is ready!"
|
||||
|
||||
# Wait for test databases
|
||||
timeout 60 bash -c 'until nc -z localhost 5000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5001; do sleep 2; done'
|
||||
@@ -310,34 +297,6 @@ jobs:
|
||||
mkdir -p databasus-data/backups
|
||||
mkdir -p databasus-data/temp
|
||||
|
||||
- name: Cache PostgreSQL client tools
|
||||
id: cache-postgres
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: /usr/lib/postgresql
|
||||
key: postgres-clients-12-18-v1
|
||||
|
||||
- name: Cache MySQL client tools
|
||||
id: cache-mysql
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: backend/tools/mysql
|
||||
key: mysql-clients-57-80-84-9-v1
|
||||
|
||||
- name: Cache MariaDB client tools
|
||||
id: cache-mariadb
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: backend/tools/mariadb
|
||||
key: mariadb-clients-106-121-v1
|
||||
|
||||
- name: Cache MongoDB Database Tools
|
||||
id: cache-mongodb
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: backend/tools/mongodb
|
||||
key: mongodb-database-tools-100.10.0-v1
|
||||
|
||||
- name: Install MySQL dependencies
|
||||
run: |
|
||||
sudo apt-get update -qq
|
||||
@@ -345,31 +304,58 @@ jobs:
|
||||
sudo ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5
|
||||
sudo ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5
|
||||
|
||||
- name: Install PostgreSQL, MySQL, MariaDB and MongoDB client tools
|
||||
if: steps.cache-postgres.outputs.cache-hit != 'true' || steps.cache-mysql.outputs.cache-hit != 'true' || steps.cache-mariadb.outputs.cache-hit != 'true' || steps.cache-mongodb.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
chmod +x backend/tools/download_linux.sh
|
||||
cd backend/tools
|
||||
./download_linux.sh
|
||||
|
||||
- name: Setup PostgreSQL symlinks (when using cache)
|
||||
if: steps.cache-postgres.outputs.cache-hit == 'true'
|
||||
- name: Setup PostgreSQL, MySQL and MariaDB client tools from pre-built assets
|
||||
run: |
|
||||
cd backend/tools
|
||||
mkdir -p postgresql
|
||||
|
||||
# Create directory structure
|
||||
mkdir -p postgresql mysql mariadb mongodb/bin
|
||||
|
||||
# Copy PostgreSQL client tools (12-18) from pre-built assets
|
||||
for version in 12 13 14 15 16 17 18; do
|
||||
version_dir="postgresql/postgresql-$version"
|
||||
mkdir -p "$version_dir/bin"
|
||||
pg_bin_dir="/usr/lib/postgresql/$version/bin"
|
||||
if [ -d "$pg_bin_dir" ]; then
|
||||
ln -sf "$pg_bin_dir/pg_dump" "$version_dir/bin/pg_dump"
|
||||
ln -sf "$pg_bin_dir/pg_dumpall" "$version_dir/bin/pg_dumpall"
|
||||
ln -sf "$pg_bin_dir/psql" "$version_dir/bin/psql"
|
||||
ln -sf "$pg_bin_dir/pg_restore" "$version_dir/bin/pg_restore"
|
||||
ln -sf "$pg_bin_dir/createdb" "$version_dir/bin/createdb"
|
||||
ln -sf "$pg_bin_dir/dropdb" "$version_dir/bin/dropdb"
|
||||
fi
|
||||
mkdir -p postgresql/postgresql-$version
|
||||
cp -r ../../assets/tools/x64/postgresql/postgresql-$version/bin postgresql/postgresql-$version/
|
||||
done
|
||||
|
||||
# Copy MySQL client tools (5.7, 8.0, 8.4, 9) from pre-built assets
|
||||
for version in 5.7 8.0 8.4 9; do
|
||||
mkdir -p mysql/mysql-$version
|
||||
cp -r ../../assets/tools/x64/mysql/mysql-$version/bin mysql/mysql-$version/
|
||||
done
|
||||
|
||||
# Copy MariaDB client tools (10.6, 12.1) from pre-built assets
|
||||
for version in 10.6 12.1; do
|
||||
mkdir -p mariadb/mariadb-$version
|
||||
cp -r ../../assets/tools/x64/mariadb/mariadb-$version/bin mariadb/mariadb-$version/
|
||||
done
|
||||
|
||||
# Make all binaries executable
|
||||
chmod +x postgresql/*/bin/*
|
||||
chmod +x mysql/*/bin/*
|
||||
chmod +x mariadb/*/bin/*
|
||||
|
||||
echo "Pre-built client tools setup complete"
|
||||
|
||||
- name: Install MongoDB Database Tools
|
||||
run: |
|
||||
cd backend/tools
|
||||
|
||||
# MongoDB Database Tools must be downloaded (not in pre-built assets)
|
||||
# They are backward compatible - single version supports all servers (4.0-8.0)
|
||||
MONGODB_TOOLS_URL="https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-x86_64-100.10.0.deb"
|
||||
|
||||
echo "Downloading MongoDB Database Tools..."
|
||||
wget -q "$MONGODB_TOOLS_URL" -O /tmp/mongodb-database-tools.deb
|
||||
|
||||
echo "Installing MongoDB Database Tools..."
|
||||
sudo dpkg -i /tmp/mongodb-database-tools.deb || sudo apt-get install -f -y --no-install-recommends
|
||||
|
||||
# Create symlinks to tools directory
|
||||
ln -sf /usr/bin/mongodump mongodb/bin/mongodump
|
||||
ln -sf /usr/bin/mongorestore mongodb/bin/mongorestore
|
||||
|
||||
rm -f /tmp/mongodb-database-tools.deb
|
||||
echo "MongoDB Database Tools installed successfully"
|
||||
|
||||
- name: Verify MariaDB client tools exist
|
||||
run: |
|
||||
@@ -672,17 +658,6 @@ jobs:
|
||||
echo EOF
|
||||
} >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Update CITATION.cff version
|
||||
run: |
|
||||
VERSION="${{ needs.determine-version.outputs.new_version }}"
|
||||
sed -i "s/^version: .*/version: ${VERSION}/" CITATION.cff
|
||||
sed -i "s/^date-released: .*/date-released: \"$(date +%Y-%m-%d)\"/" CITATION.cff
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
git add CITATION.cff
|
||||
git commit -m "Update CITATION.cff to v${VERSION}" || true
|
||||
git push || true
|
||||
|
||||
- name: Create GitHub Release
|
||||
uses: actions/create-release@v1
|
||||
env:
|
||||
@@ -726,4 +701,4 @@ jobs:
|
||||
- name: Push Helm chart to GHCR
|
||||
run: |
|
||||
VERSION="${{ needs.determine-version.outputs.new_version }}"
|
||||
helm push databasus-${VERSION}.tgz oci://ghcr.io/databasus/charts
|
||||
helm push databasus-${VERSION}.tgz oci://ghcr.io/databasus/charts
|
||||
@@ -6,14 +6,14 @@ repos:
|
||||
hooks:
|
||||
- id: frontend-format
|
||||
name: Frontend Format (Prettier)
|
||||
entry: powershell -Command "cd frontend; npm run format"
|
||||
entry: bash -c "cd frontend && npm run format"
|
||||
language: system
|
||||
files: ^frontend/.*\.(ts|tsx|js|jsx|json|css|md)$
|
||||
pass_filenames: false
|
||||
|
||||
- id: frontend-lint
|
||||
name: Frontend Lint (ESLint)
|
||||
entry: powershell -Command "cd frontend; npm run lint"
|
||||
entry: bash -c "cd frontend && npm run lint"
|
||||
language: system
|
||||
files: ^frontend/.*\.(ts|tsx|js|jsx)$
|
||||
pass_filenames: false
|
||||
@@ -23,7 +23,14 @@ repos:
|
||||
hooks:
|
||||
- id: backend-format-and-lint
|
||||
name: Backend Format & Lint (golangci-lint)
|
||||
entry: powershell -Command "cd backend; golangci-lint fmt; golangci-lint run"
|
||||
entry: bash -c "cd backend && golangci-lint fmt ./internal/... ./cmd/... && golangci-lint run ./internal/... ./cmd/..."
|
||||
language: system
|
||||
files: ^backend/.*\.go$
|
||||
pass_filenames: false
|
||||
pass_filenames: false
|
||||
|
||||
- id: backend-go-mod-tidy
|
||||
name: Backend Go Mod Tidy
|
||||
entry: bash -c "cd backend && go mod tidy"
|
||||
language: system
|
||||
files: ^backend/.*\.go$
|
||||
pass_filenames: false
|
||||
|
||||
@@ -32,5 +32,5 @@ keywords:
|
||||
- mongodb
|
||||
- mariadb
|
||||
license: Apache-2.0
|
||||
version: 2.19.0
|
||||
date-released: "2026-01-02"
|
||||
version: 2.21.0
|
||||
date-released: "2026-01-05"
|
||||
|
||||
38
Dockerfile
38
Dockerfile
@@ -22,7 +22,7 @@ RUN npm run build
|
||||
|
||||
# ========= BUILD BACKEND =========
|
||||
# Backend build stage
|
||||
FROM --platform=$BUILDPLATFORM golang:1.24.4 AS backend-build
|
||||
FROM --platform=$BUILDPLATFORM golang:1.24.9 AS backend-build
|
||||
|
||||
# Make TARGET args available early so tools built here match the final image arch
|
||||
ARG TARGETOS
|
||||
@@ -123,6 +123,15 @@ RUN wget -qO- https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add -
|
||||
apt-get install -y --no-install-recommends postgresql-17 && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Valkey server from debian repository
|
||||
# Valkey is only accessible internally (localhost) - not exposed outside container
|
||||
RUN wget -O /usr/share/keyrings/greensec.github.io-valkey-debian.key https://greensec.github.io/valkey-debian/public.key && \
|
||||
echo "deb [signed-by=/usr/share/keyrings/greensec.github.io-valkey-debian.key] https://greensec.github.io/valkey-debian/repo $(lsb_release -cs) main" \
|
||||
> /etc/apt/sources.list.d/valkey-debian.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends valkey && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# ========= Install rclone =========
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends rclone && \
|
||||
@@ -245,7 +254,34 @@ PG_BIN="/usr/lib/postgresql/17/bin"
|
||||
# Ensure proper ownership of data directory
|
||||
echo "Setting up data directory permissions..."
|
||||
mkdir -p /databasus-data/pgdata
|
||||
mkdir -p /databasus-data/temp
|
||||
mkdir -p /databasus-data/backups
|
||||
chown -R postgres:postgres /databasus-data
|
||||
chmod 700 /databasus-data/temp
|
||||
|
||||
# ========= Start Valkey (internal cache) =========
|
||||
echo "Configuring Valkey cache..."
|
||||
cat > /tmp/valkey.conf << 'VALKEY_CONFIG'
|
||||
port 6379
|
||||
bind 127.0.0.1
|
||||
protected-mode yes
|
||||
save ""
|
||||
maxmemory 256mb
|
||||
maxmemory-policy allkeys-lru
|
||||
VALKEY_CONFIG
|
||||
|
||||
echo "Starting Valkey..."
|
||||
valkey-server /tmp/valkey.conf &
|
||||
VALKEY_PID=\$!
|
||||
|
||||
echo "Waiting for Valkey to be ready..."
|
||||
for i in {1..30}; do
|
||||
if valkey-cli ping >/dev/null 2>&1; then
|
||||
echo "Valkey is ready!"
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
# Initialize PostgreSQL if not already initialized
|
||||
if [ ! -s "/databasus-data/pgdata/PG_VERSION" ]; then
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
<img src="assets/logo.svg" alt="Databasus Logo" width="250"/>
|
||||
|
||||
<h3>Backup tool for PostgreSQL, MySQL and MongoDB</h3>
|
||||
<p>Databasus is a free, open source and self-hosted tool to backup databases. Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.). Previously known as Postgresus (see migration guide).</p>
|
||||
<p>Databasus is a free, open source and self-hosted tool to backup databases (with focus on PostgreSQL). Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.). Previously known as Postgresus (see migration guide).</p>
|
||||
|
||||
<!-- Badges -->
|
||||
[](https://www.postgresql.org/)
|
||||
|
||||
@@ -1,152 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
Always place private methods to the bottom of file
|
||||
|
||||
**This rule applies to ALL Go files including tests, services, controllers, repositories, etc.**
|
||||
|
||||
In Go, exported (public) functions/methods start with uppercase letters, while unexported (private) ones start with lowercase letters.
|
||||
|
||||
## Structure Order:
|
||||
|
||||
1. Type definitions and constants
|
||||
2. Public methods/functions (uppercase)
|
||||
3. Private methods/functions (lowercase)
|
||||
|
||||
## Examples:
|
||||
|
||||
### Service with methods:
|
||||
|
||||
```go
|
||||
type UserService struct {
|
||||
repository *UserRepository
|
||||
}
|
||||
|
||||
// Public methods first
|
||||
func (s *UserService) CreateUser(user *User) error {
|
||||
if err := s.validateUser(user); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.repository.Save(user)
|
||||
}
|
||||
|
||||
func (s *UserService) GetUser(id uuid.UUID) (*User, error) {
|
||||
return s.repository.FindByID(id)
|
||||
}
|
||||
|
||||
// Private methods at the bottom
|
||||
func (s *UserService) validateUser(user *User) error {
|
||||
if user.Name == "" {
|
||||
return errors.New("name is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
### Package-level functions:
|
||||
|
||||
```go
|
||||
package utils
|
||||
|
||||
// Public functions first
|
||||
func ProcessData(data []byte) (Result, error) {
|
||||
cleaned := sanitizeInput(data)
|
||||
return parseData(cleaned)
|
||||
}
|
||||
|
||||
func ValidateInput(input string) bool {
|
||||
return isValidFormat(input) && checkLength(input)
|
||||
}
|
||||
|
||||
// Private functions at the bottom
|
||||
func sanitizeInput(data []byte) []byte {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func parseData(data []byte) (Result, error) {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func isValidFormat(input string) bool {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func checkLength(input string) bool {
|
||||
// implementation
|
||||
}
|
||||
```
|
||||
|
||||
### Test files:
|
||||
|
||||
```go
|
||||
package user_test
|
||||
|
||||
// Public test functions first
|
||||
func Test_CreateUser_ValidInput_UserCreated(t *testing.T) {
|
||||
user := createTestUser()
|
||||
result, err := service.CreateUser(user)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func Test_GetUser_ExistingUser_ReturnsUser(t *testing.T) {
|
||||
user := createTestUser()
|
||||
// test implementation
|
||||
}
|
||||
|
||||
// Private helper functions at the bottom
|
||||
func createTestUser() *User {
|
||||
return &User{
|
||||
Name: "Test User",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestDatabase() *Database {
|
||||
// setup implementation
|
||||
}
|
||||
```
|
||||
|
||||
### Controller example:
|
||||
|
||||
```go
|
||||
type ProjectController struct {
|
||||
service *ProjectService
|
||||
}
|
||||
|
||||
// Public HTTP handlers first
|
||||
func (c *ProjectController) CreateProject(ctx *gin.Context) {
|
||||
var request CreateProjectRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
c.handleError(ctx, err)
|
||||
return
|
||||
}
|
||||
// handler logic
|
||||
}
|
||||
|
||||
func (c *ProjectController) GetProject(ctx *gin.Context) {
|
||||
projectID := c.extractProjectID(ctx)
|
||||
// handler logic
|
||||
}
|
||||
|
||||
// Private helper methods at the bottom
|
||||
func (c *ProjectController) handleError(ctx *gin.Context, err error) {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
}
|
||||
|
||||
func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
|
||||
return uuid.MustParse(ctx.Param("projectId"))
|
||||
}
|
||||
```
|
||||
|
||||
## Key Points:
|
||||
|
||||
- **Exported/Public** = starts with uppercase letter (CreateUser, GetProject)
|
||||
- **Unexported/Private** = starts with lowercase letter (validateUser, handleError)
|
||||
- This improves code readability by showing the public API first
|
||||
- Private helpers are implementation details, so they go at the bottom
|
||||
- Apply this rule consistently across ALL Go files in the project
|
||||
@@ -1,45 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
## Comment Guidelines
|
||||
|
||||
1. **No obvious comments** - Don't state what the code already clearly shows
|
||||
2. **Functions and variables should have meaningful names** - Code should be self-documenting
|
||||
3. **Comments for unclear code only** - Only add comments when code logic isn't immediately clear
|
||||
|
||||
## Key Principles:
|
||||
|
||||
- **Code should tell a story** - Use descriptive variable and function names
|
||||
- **Comments explain WHY, not WHAT** - The code shows what happens, comments explain business logic or complex decisions
|
||||
- **Prefer refactoring over commenting** - If code needs explaining, consider making it clearer instead
|
||||
- **API documentation is required** - Swagger comments for all HTTP endpoints are mandatory
|
||||
- **Complex algorithms deserve comments** - Mathematical formulas, business rules, or non-obvious optimizations
|
||||
|
||||
Example of useless comment:
|
||||
|
||||
1.
|
||||
|
||||
```sql
|
||||
// Create projects table
|
||||
CREATE TABLE projects (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
name TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
```
|
||||
|
||||
2.
|
||||
|
||||
```go
|
||||
// Create test project
|
||||
project := CreateTestProject(projectName, user, router)
|
||||
```
|
||||
|
||||
3.
|
||||
|
||||
```go
|
||||
// CreateValidLogItems creates valid log items for testing
|
||||
func CreateValidLogItems(count int, uniqueID string) []logs_receiving.LogItemRequestDTO {
|
||||
```
|
||||
@@ -1,133 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
1. When we write controller:
|
||||
|
||||
- we combine all routes to single controller
|
||||
- names them as .WhatWeDo (not "handlers") concept
|
||||
|
||||
2. We use gin and \*gin.Context for all routes.
|
||||
Example:
|
||||
|
||||
func (c *TasksController) GetAvailableTasks(ctx *gin.Context) ...
|
||||
|
||||
3. We document all routes with Swagger in the following format:
|
||||
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
auditLogService \*AuditLogService
|
||||
}
|
||||
|
||||
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// All audit log endpoints require authentication (handled in main.go)
|
||||
auditRoutes := router.Group("/audit-logs")
|
||||
|
||||
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
|
||||
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
|
||||
|
||||
}
|
||||
|
||||
// GetGlobalAuditLogs
|
||||
// @Summary Get global audit logs (ADMIN only)
|
||||
// @Description Retrieve all audit logs across the system
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/global [get]
|
||||
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(\*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "only administrators can view global audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
|
||||
}
|
||||
|
||||
// GetUserAuditLogs
|
||||
// @Summary Get user audit logs
|
||||
// @Description Retrieve audit logs for a specific user
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param userId path string true "User ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/users/{userId} [get]
|
||||
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(\*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
userIDStr := ctx.Param("userId")
|
||||
targetUserID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
|
||||
}
|
||||
@@ -1,671 +0,0 @@
|
||||
---
|
||||
alwaysApply: false
|
||||
---
|
||||
|
||||
This is example of CRUD:
|
||||
|
||||
------ backend/internal/features/audit_logs/controller.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
auditLogService *AuditLogService
|
||||
}
|
||||
|
||||
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// All audit log endpoints require authentication (handled in main.go)
|
||||
auditRoutes := router.Group("/audit-logs")
|
||||
|
||||
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
|
||||
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
|
||||
}
|
||||
|
||||
// GetGlobalAuditLogs
|
||||
// @Summary Get global audit logs (ADMIN only)
|
||||
// @Description Retrieve all audit logs across the system
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/global [get]
|
||||
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "only administrators can view global audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetUserAuditLogs
|
||||
// @Summary Get user audit logs
|
||||
// @Description Retrieve audit logs for a specific user
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param userId path string true "User ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/users/{userId} [get]
|
||||
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
userIDStr := ctx.Param("userId")
|
||||
targetUserID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/controller_test.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_GetGlobalAuditLogs_AdminSucceedsAndMemberGetsForbidden(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
memberUser := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
projectID := uuid.New()
|
||||
|
||||
// Create test logs
|
||||
createAuditLog(service, "Test log with user", &adminUser.UserID, nil)
|
||||
createAuditLog(service, "Test log with project", nil, &projectID)
|
||||
createAuditLog(service, "Test log standalone", nil, nil)
|
||||
|
||||
// Test ADMIN can access global logs
|
||||
var response GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
"/api/v1/audit-logs/global?limit=10", "Bearer "+adminUser.Token, http.StatusOK, &response)
|
||||
|
||||
assert.GreaterOrEqual(t, len(response.AuditLogs), 3)
|
||||
assert.GreaterOrEqual(t, response.Total, int64(3))
|
||||
|
||||
messages := extractMessages(response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test log with user")
|
||||
assert.Contains(t, messages, "Test log with project")
|
||||
assert.Contains(t, messages, "Test log standalone")
|
||||
|
||||
// Test MEMBER cannot access global logs
|
||||
resp := test_utils.MakeGetRequest(t, router, "/api/v1/audit-logs/global",
|
||||
"Bearer "+memberUser.Token, http.StatusForbidden)
|
||||
assert.Contains(t, string(resp.Body), "only administrators can view global audit logs")
|
||||
}
|
||||
|
||||
func Test_GetUserAuditLogs_PermissionsEnforcedCorrectly(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
projectID := uuid.New()
|
||||
|
||||
// Create test logs for different users
|
||||
createAuditLog(service, "Test log user1 first", &user1.UserID, nil)
|
||||
createAuditLog(service, "Test log user1 second", &user1.UserID, &projectID)
|
||||
createAuditLog(service, "Test log user2 first", &user2.UserID, nil)
|
||||
createAuditLog(service, "Test log user2 second", &user2.UserID, &projectID)
|
||||
createAuditLog(service, "Test project log", nil, &projectID)
|
||||
|
||||
// Test ADMIN can view any user's logs
|
||||
var user1Response GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=10", user1.UserID.String()),
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &user1Response)
|
||||
|
||||
assert.Equal(t, 2, len(user1Response.AuditLogs))
|
||||
messages := extractMessages(user1Response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test log user1 first")
|
||||
assert.Contains(t, messages, "Test log user1 second")
|
||||
|
||||
// Test user can view own logs
|
||||
var ownLogsResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s", user2.UserID.String()),
|
||||
"Bearer "+user2.Token, http.StatusOK, &ownLogsResponse)
|
||||
assert.Equal(t, 2, len(ownLogsResponse.AuditLogs))
|
||||
|
||||
// Test user cannot view other user's logs
|
||||
resp := test_utils.MakeGetRequest(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s", user1.UserID.String()),
|
||||
"Bearer "+user2.Token, http.StatusForbidden)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
func Test_FilterAuditLogsByTime_ReturnsOnlyLogsBeforeDate(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
db := storage.GetDb()
|
||||
baseTime := time.Now().UTC()
|
||||
|
||||
// Create logs with different timestamps
|
||||
createTimedLog(db, &adminUser.UserID, "Test old log", baseTime.Add(-2*time.Hour))
|
||||
createTimedLog(db, &adminUser.UserID, "Test recent log", baseTime.Add(-30*time.Minute))
|
||||
createAuditLog(service, "Test current log", &adminUser.UserID, nil)
|
||||
|
||||
// Test filtering - get logs before 1 hour ago
|
||||
beforeTime := baseTime.Add(-1 * time.Hour)
|
||||
var filteredResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/global?beforeDate=%s", beforeTime.Format(time.RFC3339)),
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &filteredResponse)
|
||||
|
||||
// Verify only old log is returned
|
||||
messages := extractMessages(filteredResponse.AuditLogs)
|
||||
assert.Contains(t, messages, "Test old log")
|
||||
assert.NotContains(t, messages, "Test recent log")
|
||||
assert.NotContains(t, messages, "Test current log")
|
||||
|
||||
// Test without filter - should get all logs
|
||||
var allResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router, "/api/v1/audit-logs/global",
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &allResponse)
|
||||
assert.GreaterOrEqual(t, len(allResponse.AuditLogs), 3)
|
||||
}
|
||||
|
||||
func createRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
SetupDependencies()
|
||||
|
||||
v1 := router.Group("/api/v1")
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
GetAuditLogController().RegisterRoutes(protected.(*gin.RouterGroup))
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/di.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var auditLogRepository = &AuditLogRepository{}
|
||||
var auditLogService = &AuditLogService{
|
||||
auditLogRepository: auditLogRepository,
|
||||
logger: logger.GetLogger(),
|
||||
}
|
||||
var auditLogController = &AuditLogController{
|
||||
auditLogService: auditLogService,
|
||||
}
|
||||
|
||||
func GetAuditLogService() *AuditLogService {
|
||||
return auditLogService
|
||||
}
|
||||
|
||||
func GetAuditLogController() *AuditLogController {
|
||||
return auditLogController
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/dto.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import "time"
|
||||
|
||||
type GetAuditLogsRequest struct {
|
||||
Limit int `form:"limit" json:"limit"`
|
||||
Offset int `form:"offset" json:"offset"`
|
||||
BeforeDate *time.Time `form:"beforeDate" json:"beforeDate"`
|
||||
}
|
||||
|
||||
type GetAuditLogsResponse struct {
|
||||
AuditLogs []*AuditLog `json:"auditLogs"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/models.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLog struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id"`
|
||||
UserID *uuid.UUID `json:"userId" gorm:"column:user_id"`
|
||||
ProjectID *uuid.UUID `json:"projectId" gorm:"column:project_id"`
|
||||
Message string `json:"message" gorm:"column:message"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
||||
}
|
||||
|
||||
func (AuditLog) TableName() string {
|
||||
return "audit_logs"
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/repository.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogRepository struct{}
|
||||
|
||||
func (r *AuditLogRepository) Create(auditLog *AuditLog) error {
|
||||
if auditLog.ID == uuid.Nil {
|
||||
auditLog.ID = uuid.New()
|
||||
}
|
||||
|
||||
return storage.GetDb().Create(auditLog).Error
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetGlobal(limit, offset int, beforeDate *time.Time) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetByUser(
|
||||
userID uuid.UUID,
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().
|
||||
Where("user_id = ?", userID).
|
||||
Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetByProject(
|
||||
projectID uuid.UUID,
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().
|
||||
Where("project_id = ?", projectID).
|
||||
Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
|
||||
var count int64
|
||||
query := storage.GetDb().Model(&AuditLog{})
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/service.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogService struct {
|
||||
auditLogRepository *AuditLogRepository
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *AuditLogService) WriteAuditLog(
|
||||
message string,
|
||||
userID *uuid.UUID,
|
||||
projectID *uuid.UUID,
|
||||
) {
|
||||
auditLog := &AuditLog{
|
||||
UserID: userID,
|
||||
ProjectID: projectID,
|
||||
Message: message,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
err := s.auditLogRepository.Create(auditLog)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to create audit log", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuditLogService) CreateAuditLog(auditLog *AuditLog) error {
|
||||
return s.auditLogRepository.Create(auditLog)
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetGlobalAuditLogs(
|
||||
user *user_models.User,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
if user.Role != user_enums.UserRoleAdmin {
|
||||
return nil, errors.New("only administrators can view global audit logs")
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetGlobal(limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
total, err := s.auditLogRepository.CountGlobal(request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: total,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetUserAuditLogs(
|
||||
targetUserID uuid.UUID,
|
||||
user *user_models.User,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
// Users can view their own logs, ADMIN can view any user's logs
|
||||
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
|
||||
return nil, errors.New("insufficient permissions to view user audit logs")
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetByUser(targetUserID, limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: int64(len(auditLogs)),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetProjectAuditLogs(
|
||||
projectID uuid.UUID,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetByProject(projectID, limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: int64(len(auditLogs)),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/service_test.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func Test_AuditLogs_ProjectSpecificLogs(t *testing.T) {
|
||||
service := GetAuditLogService()
|
||||
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
project1ID, project2ID := uuid.New(), uuid.New()
|
||||
|
||||
// Create test logs for projects
|
||||
createAuditLog(service, "Test project1 log first", &user1.UserID, &project1ID)
|
||||
createAuditLog(service, "Test project1 log second", &user2.UserID, &project1ID)
|
||||
createAuditLog(service, "Test project2 log first", &user1.UserID, &project2ID)
|
||||
createAuditLog(service, "Test project2 log second", &user2.UserID, &project2ID)
|
||||
createAuditLog(service, "Test no project log", &user1.UserID, nil)
|
||||
|
||||
request := &GetAuditLogsRequest{Limit: 10, Offset: 0}
|
||||
|
||||
// Test project 1 logs
|
||||
project1Response, err := service.GetProjectAuditLogs(project1ID, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(project1Response.AuditLogs))
|
||||
|
||||
messages := extractMessages(project1Response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test project1 log first")
|
||||
assert.Contains(t, messages, "Test project1 log second")
|
||||
for _, log := range project1Response.AuditLogs {
|
||||
assert.Equal(t, &project1ID, log.ProjectID)
|
||||
}
|
||||
|
||||
// Test project 2 logs
|
||||
project2Response, err := service.GetProjectAuditLogs(project2ID, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(project2Response.AuditLogs))
|
||||
|
||||
messages2 := extractMessages(project2Response.AuditLogs)
|
||||
assert.Contains(t, messages2, "Test project2 log first")
|
||||
assert.Contains(t, messages2, "Test project2 log second")
|
||||
|
||||
// Test pagination
|
||||
limitedResponse, err := service.GetProjectAuditLogs(project1ID,
|
||||
&GetAuditLogsRequest{Limit: 1, Offset: 0})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(limitedResponse.AuditLogs))
|
||||
assert.Equal(t, 1, limitedResponse.Limit)
|
||||
|
||||
// Test beforeDate filter
|
||||
beforeTime := time.Now().UTC().Add(-1 * time.Minute)
|
||||
filteredResponse, err := service.GetProjectAuditLogs(project1ID,
|
||||
&GetAuditLogsRequest{Limit: 10, BeforeDate: &beforeTime})
|
||||
assert.NoError(t, err)
|
||||
for _, log := range filteredResponse.AuditLogs {
|
||||
assert.True(t, log.CreatedAt.Before(beforeTime))
|
||||
}
|
||||
}
|
||||
|
||||
func createAuditLog(service *AuditLogService, message string, userID, projectID *uuid.UUID) {
|
||||
service.WriteAuditLog(message, userID, projectID)
|
||||
}
|
||||
|
||||
func extractMessages(logs []*AuditLog) []string {
|
||||
messages := make([]string, len(logs))
|
||||
for i, log := range logs {
|
||||
messages[i] = log.Message
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt time.Time) {
|
||||
log := &AuditLog{
|
||||
ID: uuid.New(),
|
||||
UserID: userID,
|
||||
Message: message,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
db.Create(log)
|
||||
}
|
||||
|
||||
```
|
||||
@@ -1,74 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
For DI files use implicit fields declaration styles (espesially
|
||||
for controllers, services, repositories, use cases, etc., not simple
|
||||
data structures).
|
||||
|
||||
So, instead of:
|
||||
|
||||
var orderController = &OrderController{
|
||||
orderService: orderService,
|
||||
botUserService: bot_users.GetBotUserService(),
|
||||
botService: bots.GetBotService(),
|
||||
userService: users.GetUserService(),
|
||||
}
|
||||
|
||||
Use:
|
||||
|
||||
var orderController = &OrderController{
|
||||
orderService,
|
||||
bot_users.GetBotUserService(),
|
||||
bots.GetBotService(),
|
||||
users.GetUserService(),
|
||||
}
|
||||
|
||||
This is needed to avoid forgetting to update DI style
|
||||
when we add new dependency.
|
||||
|
||||
---
|
||||
|
||||
Please force such usage if file look like this (see some
|
||||
services\controllers\repos definitions and getters):
|
||||
|
||||
var orderBackgroundService = &OrderBackgroundService{
|
||||
orderService: orderService,
|
||||
orderPaymentRepository: orderPaymentRepository,
|
||||
botService: bots.GetBotService(),
|
||||
paymentSettingsService: payment_settings.GetPaymentSettingsService(),
|
||||
|
||||
orderSubscriptionListeners: []OrderSubscriptionListener{},
|
||||
}
|
||||
|
||||
var orderController = &OrderController{
|
||||
orderService: orderService,
|
||||
botUserService: bot_users.GetBotUserService(),
|
||||
botService: bots.GetBotService(),
|
||||
userService: users.GetUserService(),
|
||||
}
|
||||
|
||||
func GetUniquePaymentRepository() *repositories.UniquePaymentRepository {
|
||||
return uniquePaymentRepository
|
||||
}
|
||||
|
||||
func GetOrderPaymentRepository() *repositories.OrderPaymentRepository {
|
||||
return orderPaymentRepository
|
||||
}
|
||||
|
||||
func GetOrderService() *OrderService {
|
||||
return orderService
|
||||
}
|
||||
|
||||
func GetOrderController() *OrderController {
|
||||
return orderController
|
||||
}
|
||||
|
||||
func GetOrderBackgroundService() *OrderBackgroundService {
|
||||
return orderBackgroundService
|
||||
}
|
||||
|
||||
func GetOrderRepository() *repositories.OrderRepository {
|
||||
return orderRepository
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
When writting migrations:
|
||||
|
||||
- write them for PostgreSQL
|
||||
- for PRIMARY UUID keys use gen_random_uuid()
|
||||
- for time use TIMESTAMPTZ (timestamp with zone)
|
||||
- split table, constraint and indexes declaration (table first, them other one by one)
|
||||
- format SQL in pretty way (add spaces, align columns types), constraints split by lines. The example:
|
||||
|
||||
CREATE TABLE marketplace_info (
|
||||
bot_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
title TEXT NOT NULL,
|
||||
description TEXT NOT NULL,
|
||||
short_description TEXT NOT NULL,
|
||||
tutorial_url TEXT,
|
||||
info_order BIGINT NOT NULL DEFAULT 0,
|
||||
is_published BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
|
||||
ALTER TABLE marketplace_info_images
|
||||
ADD CONSTRAINT fk_marketplace_info_images_bot_id
|
||||
FOREIGN KEY (bot_id)
|
||||
REFERENCES marketplace_info (bot_id);
|
||||
@@ -1,12 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
When applying changes, do not forget to refactor old code.
|
||||
|
||||
You can shortify, make more readable, improve code quality, etc.
|
||||
Common logic can be extracted to functions, constants, files, etc.
|
||||
|
||||
After each large change with more than ~50-100 lines of code - always run `make lint` (from backend root folder) and, if you change frontend, run `npm run format` (from frontend root folder).
|
||||
@@ -1,147 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
After writing tests, always launch them and verify that they pass.
|
||||
|
||||
## Test Naming Format
|
||||
|
||||
Use these naming patterns:
|
||||
|
||||
- `Test_WhatWeDo_WhatWeExpect`
|
||||
- `Test_WhatWeDo_WhichConditions_WhatWeExpect`
|
||||
|
||||
## Examples from Real Codebase:
|
||||
|
||||
- `Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated`
|
||||
- `Test_UpdateProject_WhenUserIsProjectAdmin_ProjectUpdated`
|
||||
- `Test_DeleteApiKey_WhenUserIsProjectMember_ReturnsForbidden`
|
||||
- `Test_GetProjectAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly`
|
||||
- `Test_ProjectLifecycleE2E_CompletesSuccessfully`
|
||||
|
||||
## Testing Philosophy
|
||||
|
||||
**Prefer Controllers Over Unit Tests:**
|
||||
|
||||
- Test through HTTP endpoints via controllers whenever possible
|
||||
- Avoid testing repositories, services in isolation - test via API instead
|
||||
- Only use unit tests for complex model logic when no API exists
|
||||
- Name test files `controller_test.go` or `service_test.go`, not `integration_test.go`
|
||||
|
||||
**Extract Common Logic to Testing Utilities:**
|
||||
|
||||
- Create `testing.go` or `testing/testing.go` files for shared test utilities
|
||||
- Extract router creation, user setup, models creation helpers (in API, not just structs creation)
|
||||
- Reuse common patterns across different test files
|
||||
|
||||
**Refactor Existing Tests:**
|
||||
|
||||
- When working with existing tests, always look for opportunities to refactor and improve
|
||||
- Extract repetitive setup code to common utilities
|
||||
- Simplify complex tests by breaking them into smaller, focused tests
|
||||
- Replace inline test data creation with reusable helper functions
|
||||
- Consolidate similar test patterns across different test files
|
||||
- Make tests more readable and maintainable for other developers
|
||||
|
||||
## Testing Utilities Structure
|
||||
|
||||
**Create `testing.go` or `testing/testing.go` files with common utilities:**
|
||||
|
||||
```go
|
||||
package projects_testing
|
||||
|
||||
// CreateTestRouter creates unified router for all controllers
|
||||
func CreateTestRouter(controllers ...ControllerInterface) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
v1 := router.Group("/api/v1")
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
|
||||
for _, controller := range controllers {
|
||||
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
|
||||
controller.RegisterRoutes(routerGroup)
|
||||
}
|
||||
}
|
||||
return router
|
||||
}
|
||||
|
||||
// CreateTestProjectViaAPI creates project through HTTP API
|
||||
func CreateTestProjectViaAPI(name string, owner *users_dto.SignInResponseDTO, router *gin.Engine) (*projects_models.Project, string) {
|
||||
request := projects_dto.CreateProjectRequestDTO{Name: name}
|
||||
w := MakeAPIRequest(router, "POST", "/api/v1/projects", "Bearer "+owner.Token, request)
|
||||
// Handle response...
|
||||
return project, owner.Token
|
||||
}
|
||||
|
||||
// AddMemberToProject adds member via API call
|
||||
func AddMemberToProject(project *projects_models.Project, member *users_dto.SignInResponseDTO, role users_enums.ProjectRole, ownerToken string, router *gin.Engine) {
|
||||
// Implementation...
|
||||
}
|
||||
```
|
||||
|
||||
## Controller Test Examples
|
||||
|
||||
**Permission-based testing:**
|
||||
|
||||
```go
|
||||
func Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated(t *testing.T) {
|
||||
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
project, _ := projects_testing.CreateTestProjectViaAPI("Test Project", owner, router)
|
||||
|
||||
request := CreateApiKeyRequestDTO{Name: "Test API Key"}
|
||||
var response ApiKey
|
||||
test_utils.MakePostRequestAndUnmarshal(t, router, "/api/v1/projects/api-keys/"+project.ID.String(), "Bearer "+owner.Token, request, http.StatusOK, &response)
|
||||
|
||||
assert.Equal(t, "Test API Key", response.Name)
|
||||
assert.NotEmpty(t, response.Token)
|
||||
}
|
||||
```
|
||||
|
||||
**Cross-project security testing:**
|
||||
|
||||
```go
|
||||
func Test_UpdateApiKey_WithApiKeyFromDifferentProject_ReturnsBadRequest(t *testing.T) {
|
||||
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
project1, _ := projects_testing.CreateTestProjectViaAPI("Project 1", owner1, router)
|
||||
project2, _ := projects_testing.CreateTestProjectViaAPI("Project 2", owner2, router)
|
||||
|
||||
apiKey := CreateTestApiKey("Cross Project Key", project1.ID, owner1.Token, router)
|
||||
|
||||
// Try to update via different project endpoint
|
||||
request := UpdateApiKeyRequestDTO{Name: &"Hacked Key"}
|
||||
resp := test_utils.MakePutRequest(t, router, "/api/v1/projects/api-keys/"+project2.ID.String()+"/"+apiKey.ID.String(), "Bearer "+owner2.Token, request, http.StatusBadRequest)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "API key does not belong to this project")
|
||||
}
|
||||
```
|
||||
|
||||
**E2E lifecycle testing:**
|
||||
|
||||
```go
|
||||
func Test_ProjectLifecycleE2E_CompletesSuccessfully(t *testing.T) {
|
||||
router := projects_testing.CreateTestRouter(GetProjectController(), GetMembershipController())
|
||||
|
||||
// 1. Create project
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
project := projects_testing.CreateTestProject("E2E Project", owner, router)
|
||||
|
||||
// 2. Add member
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
projects_testing.AddMemberToProject(project, member, users_enums.ProjectRoleMember, owner.Token, router)
|
||||
|
||||
// 3. Promote to admin
|
||||
projects_testing.ChangeMemberRole(project, member.UserID, users_enums.ProjectRoleAdmin, owner.Token, router)
|
||||
|
||||
// 4. Transfer ownership
|
||||
projects_testing.TransferProjectOwnership(project, member.UserID, owner.Token, router)
|
||||
|
||||
// 5. Verify new owner can manage project
|
||||
finalProject := projects_testing.GetProject(project.ID, member.Token, router)
|
||||
assert.Equal(t, project.ID, finalProject.ID)
|
||||
}
|
||||
```
|
||||
@@ -1,6 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
Always use time.Now().UTC() instead of time.Now()
|
||||
@@ -2,8 +2,10 @@
|
||||
DEV_DB_NAME=databasus
|
||||
DEV_DB_USERNAME=postgres
|
||||
DEV_DB_PASSWORD=Q1234567
|
||||
#app
|
||||
# app
|
||||
ENV_MODE=development
|
||||
# logging
|
||||
SHOW_DB_INSTALLATION_VERIFICATION_LOGS=true
|
||||
# db
|
||||
DATABASE_DSN=host=dev-db user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
|
||||
DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
|
||||
@@ -11,6 +13,12 @@ DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
|
||||
GOOSE_DRIVER=postgres
|
||||
GOOSE_DBSTRING=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
|
||||
GOOSE_MIGRATION_DIR=./migrations
|
||||
# valkey
|
||||
VALKEY_HOST=127.0.0.1
|
||||
VALKEY_PORT=6379
|
||||
VALKEY_USERNAME=
|
||||
VALKEY_PASSWORD=
|
||||
VALKEY_IS_SSL=false
|
||||
# testing
|
||||
# to get Google Drive env variables: add storage in UI and copy data from added storage here
|
||||
TEST_GOOGLE_DRIVE_CLIENT_ID=
|
||||
|
||||
@@ -10,4 +10,10 @@ DATABASE_URL=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disab
|
||||
# migrations
|
||||
GOOSE_DRIVER=postgres
|
||||
GOOSE_DBSTRING=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
|
||||
GOOSE_MIGRATION_DIR=./migrations
|
||||
GOOSE_MIGRATION_DIR=./migrations
|
||||
# valkey
|
||||
VALKEY_HOST=127.0.0.1
|
||||
VALKEY_PORT=6379
|
||||
VALKEY_USERNAME=
|
||||
VALKEY_PASSWORD=
|
||||
VALKEY_IS_SSL=false
|
||||
3
backend/.gitignore
vendored
3
backend/.gitignore
vendored
@@ -17,4 +17,5 @@ ui/build/*
|
||||
pgdata-for-restore/
|
||||
temp/
|
||||
cmd.exe
|
||||
temp/
|
||||
temp/
|
||||
valkey-data/
|
||||
@@ -2,10 +2,10 @@ run:
|
||||
go run cmd/main.go
|
||||
|
||||
test:
|
||||
go test -p=1 -count=1 -failfast -timeout 10m .\internal\...
|
||||
go test -p=1 -count=1 -failfast -timeout 15m ./internal/...
|
||||
|
||||
lint:
|
||||
golangci-lint fmt && golangci-lint run
|
||||
golangci-lint fmt ./cmd/... ./internal/... && golangci-lint run ./cmd/... ./internal/...
|
||||
|
||||
migration-create:
|
||||
goose create $(name) sql
|
||||
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/disk"
|
||||
@@ -25,10 +27,13 @@ import (
|
||||
"databasus-backend/internal/features/restores"
|
||||
"databasus-backend/internal/features/storages"
|
||||
system_healthcheck "databasus-backend/internal/features/system/healthcheck"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
task_registry "databasus-backend/internal/features/tasks/registry"
|
||||
users_controllers "databasus-backend/internal/features/users/controllers"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
env_utils "databasus-backend/internal/util/env"
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
"databasus-backend/internal/util/logger"
|
||||
@@ -52,7 +57,23 @@ import (
|
||||
func main() {
|
||||
log := logger.GetLogger()
|
||||
|
||||
runMigrations(log)
|
||||
cache_utils.TestCacheConnection()
|
||||
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
log.Info("Clearing cache...")
|
||||
|
||||
err := cache_utils.ClearAllCache()
|
||||
if err != nil {
|
||||
log.Error("Failed to clear cache", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
runMigrations(log)
|
||||
} else {
|
||||
log.Info("Skipping migrations (IS_PRIMARY_NODE is false)")
|
||||
}
|
||||
|
||||
// create directories that used for backups and restore
|
||||
err := files_utils.EnsureDirectories([]string{
|
||||
@@ -96,7 +117,9 @@ func main() {
|
||||
enableCors(ginApp)
|
||||
setUpRoutes(ginApp)
|
||||
setUpDependencies()
|
||||
|
||||
runBackgroundTasks(log)
|
||||
|
||||
mountFrontend(ginApp)
|
||||
|
||||
startServerWithGracefulShutdown(log, ginApp)
|
||||
@@ -183,6 +206,7 @@ func setUpRoutes(r *gin.Engine) {
|
||||
userController := users_controllers.GetUserController()
|
||||
userController.RegisterRoutes(v1)
|
||||
system_healthcheck.GetHealthcheckController().RegisterRoutes(v1)
|
||||
backups.GetBackupController().RegisterPublicRoutes(v1)
|
||||
|
||||
// Setup auth middleware
|
||||
userService := users_services.GetUserService()
|
||||
@@ -217,31 +241,69 @@ func setUpDependencies() {
|
||||
audit_logs.SetupDependencies()
|
||||
notifiers.SetupDependencies()
|
||||
storages.SetupDependencies()
|
||||
backups_config.SetupDependencies()
|
||||
task_cancellation.SetupDependencies()
|
||||
}
|
||||
|
||||
func runBackgroundTasks(log *slog.Logger) {
|
||||
log.Info("Preparing to run background tasks...")
|
||||
|
||||
// Create context that will be cancelled on shutdown
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Set up signal handling for graceful shutdown
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-quit
|
||||
log.Info("Shutdown signal received, cancelling all background tasks")
|
||||
cancel()
|
||||
}()
|
||||
|
||||
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
|
||||
if err != nil {
|
||||
log.Error("Failed to clean temp folder", "error", err)
|
||||
}
|
||||
|
||||
go runWithPanicLogging(log, "backup background service", func() {
|
||||
backups.GetBackupBackgroundService().Run()
|
||||
})
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
log.Info("Starting primary node background tasks...")
|
||||
|
||||
go runWithPanicLogging(log, "restore background service", func() {
|
||||
restores.GetRestoreBackgroundService().Run()
|
||||
})
|
||||
go runWithPanicLogging(log, "backup background service", func() {
|
||||
backuping.GetBackupsScheduler().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
|
||||
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run()
|
||||
})
|
||||
go runWithPanicLogging(log, "restore background service", func() {
|
||||
restores.GetRestoreBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "audit log cleanup background service", func() {
|
||||
audit_logs.GetAuditLogBackgroundService().Run()
|
||||
})
|
||||
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
|
||||
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "audit log cleanup background service", func() {
|
||||
audit_logs.GetAuditLogBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "download token cleanup background service", func() {
|
||||
backups_download.GetDownloadTokenBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "task nodes registry background service", func() {
|
||||
task_registry.GetTaskNodesRegistry().Run(ctx)
|
||||
})
|
||||
} else {
|
||||
log.Info("Skipping primary node tasks as not primary node")
|
||||
}
|
||||
|
||||
if config.GetEnv().IsBackupNode {
|
||||
log.Info("Starting backup node background tasks...")
|
||||
|
||||
go runWithPanicLogging(log, "backup node", func() {
|
||||
backuping.GetBackuperNode().Run(ctx)
|
||||
})
|
||||
} else {
|
||||
log.Info("Skipping backup node tasks as not backup node")
|
||||
}
|
||||
}
|
||||
|
||||
func runWithPanicLogging(log *slog.Logger, serviceName string, fn func()) {
|
||||
@@ -284,16 +346,13 @@ func generateSwaggerDocs(log *slog.Logger) {
|
||||
func runMigrations(log *slog.Logger) {
|
||||
log.Info("Running database migrations...")
|
||||
|
||||
cmd := exec.Command("goose", "up")
|
||||
cmd := exec.Command("goose", "-dir", "./migrations", "up")
|
||||
cmd.Env = append(
|
||||
os.Environ(),
|
||||
"GOOSE_DRIVER=postgres",
|
||||
"GOOSE_DBSTRING="+config.GetEnv().DatabaseDsn,
|
||||
)
|
||||
|
||||
// Set the working directory to where migrations are located
|
||||
cmd.Dir = "./migrations"
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
log.Error("Failed to run migrations", "error", err, "output", string(output))
|
||||
|
||||
@@ -19,6 +19,21 @@ services:
|
||||
command: -p 5437
|
||||
shm_size: 10gb
|
||||
|
||||
# Valkey for caching
|
||||
dev-valkey:
|
||||
image: valkey/valkey:9.0.1-alpine
|
||||
ports:
|
||||
- "${VALKEY_PORT:-6379}:6379"
|
||||
volumes:
|
||||
- ./valkey-data:/data
|
||||
container_name: dev-valkey
|
||||
healthcheck:
|
||||
test: ["CMD", "valkey-cli", "ping"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 20s
|
||||
|
||||
# Test MinIO container
|
||||
test-minio:
|
||||
image: minio/minio:latest
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
module databasus-backend
|
||||
|
||||
go 1.24.4
|
||||
go 1.24.9
|
||||
|
||||
require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
|
||||
@@ -25,9 +25,9 @@ require (
|
||||
github.com/swaggo/files v1.0.1
|
||||
github.com/swaggo/gin-swagger v1.6.0
|
||||
github.com/swaggo/swag v1.16.4
|
||||
github.com/valkey-io/valkey-go v1.0.70
|
||||
go.mongodb.org/mongo-driver v1.17.6
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/time v0.14.0
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
gorm.io/gorm v1.26.1
|
||||
)
|
||||
@@ -185,6 +185,7 @@ require (
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/term v0.38.0 // indirect
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
gopkg.in/validator.v2 v2.0.1 // indirect
|
||||
moul.io/http2curl/v2 v2.3.0 // indirect
|
||||
@@ -269,7 +270,7 @@ require (
|
||||
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.38.0 // indirect
|
||||
golang.org/x/arch v0.17.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/oauth2 v0.33.0
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
|
||||
@@ -539,8 +539,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||
github.com/onsi/ginkgo/v2 v2.17.3 h1:oJcvKpIb7/8uLpDDtnQuf18xVnwKp8DTD7DQ6gTd/MU=
|
||||
github.com/onsi/ginkgo/v2 v2.17.3/go.mod h1:nP2DPOQoNsQmsVyv5rDA8JkXQoCs6goXIvr/PRJ1eCc=
|
||||
github.com/onsi/gomega v1.37.0 h1:CdEG8g0S133B4OswTDC/5XPSzE1OeP29QOioj2PID2Y=
|
||||
github.com/onsi/gomega v1.37.0/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0=
|
||||
github.com/onsi/gomega v1.38.3 h1:eTX+W6dobAYfFeGC2PV6RwXRu/MyT+cQguijutvkpSM=
|
||||
github.com/onsi/gomega v1.38.3/go.mod h1:ZCU1pkQcXDO5Sl9/VVEGlDyp+zm0m1cmeG5TOzLgdh4=
|
||||
github.com/oracle/oci-go-sdk/v65 v65.104.0 h1:l9awEvzWvxmYhy/97A0hZ87pa7BncYXmcO/S8+rvgK0=
|
||||
github.com/oracle/oci-go-sdk/v65 v65.104.0/go.mod h1:oB8jFGVc/7/zJ+DbleE8MzGHjhs2ioCz5stRTdZdIcY=
|
||||
github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg=
|
||||
@@ -660,6 +660,8 @@ github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
|
||||
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
|
||||
github.com/unknwon/goconfig v1.0.0 h1:rS7O+CmUdli1T+oDm7fYj1MwqNWtEJfNj+FqcUHML8U=
|
||||
github.com/unknwon/goconfig v1.0.0/go.mod h1:qu2ZQ/wcC/if2u32263HTVC39PeOQRSmidQk3DuDFQ8=
|
||||
github.com/valkey-io/valkey-go v1.0.70 h1:mjYNT8qiazxDAJ0QNQ8twWT/YFOkOoRd40ERV2mB49Y=
|
||||
github.com/valkey-io/valkey-go v1.0.70/go.mod h1:VGhZ6fs68Qrn2+OhH+6waZH27bjpgQOiLyUQyXuYK5k=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM=
|
||||
@@ -720,6 +722,8 @@ go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
|
||||
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
|
||||
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
||||
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/arch v0.17.0 h1:4O3dfLzd+lQewptAHqjewQZQDyEdejz3VwgeYwkZneU=
|
||||
golang.org/x/arch v0.17.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
@@ -818,8 +822,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/ilyakaznacheev/cleanenv"
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
@@ -29,6 +30,14 @@ type EnvVariables struct {
|
||||
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
|
||||
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`
|
||||
|
||||
ShowDbInstallationVerificationLogs bool `env:"SHOW_DB_INSTALLATION_VERIFICATION_LOGS"`
|
||||
|
||||
NodeID string
|
||||
IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
|
||||
IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
|
||||
IsBackupNode bool `env:"IS_BACKUP_NODE"`
|
||||
NodeNetworkThroughputMBs int `env:"NODE_NETWORK_THROUGHPUT_MBPS"`
|
||||
|
||||
DataFolder string
|
||||
TempFolder string
|
||||
SecretKeyPath string
|
||||
@@ -79,6 +88,13 @@ type EnvVariables struct {
|
||||
TestMongodb70Port string `env:"TEST_MONGODB_70_PORT"`
|
||||
TestMongodb82Port string `env:"TEST_MONGODB_82_PORT"`
|
||||
|
||||
// Valkey
|
||||
ValkeyHost string `env:"VALKEY_HOST" required:"true"`
|
||||
ValkeyPort string `env:"VALKEY_PORT" required:"true"`
|
||||
ValkeyUsername string `env:"VALKEY_USERNAME"`
|
||||
ValkeyPassword string `env:"VALKEY_PASSWORD"`
|
||||
ValkeyIsSsl bool `env:"VALKEY_IS_SSL" required:"true"`
|
||||
|
||||
// oauth
|
||||
GitHubClientID string `env:"GITHUB_CLIENT_ID"`
|
||||
GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET"`
|
||||
@@ -155,6 +171,11 @@ func loadEnvVariables() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Set default value for ShowDbInstallationVerificationLogs if not defined
|
||||
if os.Getenv("SHOW_DB_INSTALLATION_VERIFICATION_LOGS") == "" {
|
||||
env.ShowDbInstallationVerificationLogs = true
|
||||
}
|
||||
|
||||
for _, arg := range os.Args {
|
||||
if strings.Contains(arg, "test") {
|
||||
env.IsTesting = true
|
||||
@@ -178,16 +199,56 @@ func loadEnvVariables() {
|
||||
log.Info("ENV_MODE loaded", "mode", env.EnvMode)
|
||||
|
||||
env.PostgresesInstallDir = filepath.Join(backendRoot, "tools", "postgresql")
|
||||
tools.VerifyPostgresesInstallation(log, env.EnvMode, env.PostgresesInstallDir)
|
||||
tools.VerifyPostgresesInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.PostgresesInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
env.MysqlInstallDir = filepath.Join(backendRoot, "tools", "mysql")
|
||||
tools.VerifyMysqlInstallation(log, env.EnvMode, env.MysqlInstallDir)
|
||||
tools.VerifyMysqlInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.MysqlInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
env.MariadbInstallDir = filepath.Join(backendRoot, "tools", "mariadb")
|
||||
tools.VerifyMariadbInstallation(log, env.EnvMode, env.MariadbInstallDir)
|
||||
tools.VerifyMariadbInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.MariadbInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
env.MongodbInstallDir = filepath.Join(backendRoot, "tools", "mongodb")
|
||||
tools.VerifyMongodbInstallation(log, env.EnvMode, env.MongodbInstallDir)
|
||||
tools.VerifyMongodbInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.MongodbInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
env.NodeID = uuid.New().String()
|
||||
if env.NodeNetworkThroughputMBs == 0 {
|
||||
env.NodeNetworkThroughputMBs = 125 // 1 Gbit/s
|
||||
}
|
||||
|
||||
if !env.IsManyNodesMode {
|
||||
env.IsPrimaryNode = true
|
||||
env.IsBackupNode = true
|
||||
}
|
||||
|
||||
// Valkey
|
||||
if env.ValkeyHost == "" {
|
||||
log.Error("VALKEY_HOST is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
if env.ValkeyPort == "" {
|
||||
log.Error("VALKEY_PORT is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Store the data and temp folders one level below the root
|
||||
// (projectRoot/databasus-data -> /databasus-data)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
@@ -11,23 +11,25 @@ type AuditLogBackgroundService struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *AuditLogBackgroundService) Run() {
|
||||
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
|
||||
s.logger.Info("Starting audit log cleanup background service")
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
if config.IsShouldShutdown() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
@@ -50,7 +51,7 @@ func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
|
||||
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "only administrators can view global audit logs" {
|
||||
if errors.Is(err, ErrOnlyAdminsCanViewGlobalLogs) {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -99,7 +100,7 @@ func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
if errors.Is(err, ErrInsufficientPermissionsToViewLogs) {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
12
backend/internal/features/audit_logs/errors.go
Normal file
12
backend/internal/features/audit_logs/errors.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package audit_logs
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrOnlyAdminsCanViewGlobalLogs = errors.New(
|
||||
"only administrators can view global audit logs",
|
||||
)
|
||||
ErrInsufficientPermissionsToViewLogs = errors.New(
|
||||
"insufficient permissions to view user audit logs",
|
||||
)
|
||||
)
|
||||
@@ -1,7 +1,6 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
@@ -44,7 +43,7 @@ func (s *AuditLogService) GetGlobalAuditLogs(
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
if user.Role != user_enums.UserRoleAdmin {
|
||||
return nil, errors.New("only administrators can view global audit logs")
|
||||
return nil, ErrOnlyAdminsCanViewGlobalLogs
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
@@ -79,7 +78,7 @@ func (s *AuditLogService) GetUserAuditLogs(
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
// Users can view their own logs, ADMIN can view any user's logs
|
||||
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
|
||||
return nil, errors.New("insufficient permissions to view user audit logs")
|
||||
return nil, ErrInsufficientPermissionsToViewLogs
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
|
||||
@@ -1,254 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/storages"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/period"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
type BackupBackgroundService struct {
|
||||
backupService *BackupService
|
||||
backupRepository *BackupRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
|
||||
lastBackupTime time.Time
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) Run() {
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.cleanOldBackups(); err != nil {
|
||||
s.logger.Error("Failed to clean old backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
time.Sleep(1 * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) IsBackupsWorkerRunning() bool {
|
||||
// if last backup time is more than 5 minutes ago, return false
|
||||
return s.lastBackupTime.After(time.Now().UTC().Add(-5 * time.Minute))
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) failBackupsInProgress() error {
|
||||
backupsInProgress, err := s.backupRepository.FindByStatus(BackupStatusInProgress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backup := range backupsInProgress {
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
failMessage := "Backup failed due to application restart"
|
||||
backup.FailMessage = &failMessage
|
||||
backup.Status = BackupStatusFailed
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
s.backupService.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&failMessage,
|
||||
)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) cleanOldBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
backupStorePeriod := backupConfig.StorePeriod
|
||||
|
||||
if backupStorePeriod == period.PeriodForever {
|
||||
continue
|
||||
}
|
||||
|
||||
storeDuration := backupStorePeriod.ToDuration()
|
||||
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
|
||||
|
||||
oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
|
||||
backupConfig.DatabaseID,
|
||||
dateBeforeBackupsShouldBeDeleted,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find old backups for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, backup := range oldBackups {
|
||||
storage, err := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get storage by ID",
|
||||
"storageId",
|
||||
backup.StorageID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
err = storage.DeleteFile(encryptor, backup.ID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
|
||||
}
|
||||
|
||||
if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
|
||||
s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Deleted old backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) runPendingBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
if backupConfig.BackupInterval == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get last backup for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
var lastBackupTime *time.Time
|
||||
if lastBackup != nil {
|
||||
lastBackupTime = &lastBackup.CreatedAt
|
||||
}
|
||||
|
||||
remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
|
||||
|
||||
if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
|
||||
remainedBackupTryCount > 0 {
|
||||
s.logger.Info(
|
||||
"Triggering scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"intervalType",
|
||||
backupConfig.BackupInterval.Interval,
|
||||
)
|
||||
|
||||
go s.backupService.MakeBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
|
||||
s.logger.Info(
|
||||
"Successfully triggered scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRemainedBackupTryCount returns the number of remaining backup tries for a given backup.
|
||||
// If the backup is not failed or the backup config does not allow retries, it returns 0.
|
||||
// If the backup is failed and the backup config allows retries, it returns the number of remaining tries.
|
||||
// If the backup is failed and the backup config does not allow retries, it returns 0.
|
||||
func (s *BackupBackgroundService) GetRemainedBackupTryCount(lastBackup *Backup) int {
|
||||
if lastBackup == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if lastBackup.Status != BackupStatusFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
if !backupConfig.IsRetryIfFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
maxFailedTriesCount := backupConfig.MaxFailedTriesCount
|
||||
|
||||
lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
|
||||
lastBackup.DatabaseID,
|
||||
maxFailedTriesCount,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to find last backups by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
lastFailedBackups := make([]*Backup, 0)
|
||||
|
||||
for _, backup := range lastBackups {
|
||||
if backup.Status == BackupStatusFailed {
|
||||
lastFailedBackups = append(lastFailedBackups, backup)
|
||||
}
|
||||
}
|
||||
|
||||
return maxFailedTriesCount - len(lastFailedBackups)
|
||||
}
|
||||
@@ -1,321 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/period"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add old backup
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 2)
|
||||
|
||||
// cleanup
|
||||
for _, backup := range backups {
|
||||
err := backupRepository.DeleteByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add recent backup (1 hour ago)
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1) // Should still be 1 backup, no new backup created
|
||||
|
||||
// cleanup
|
||||
for _, backup := range backups {
|
||||
err := backupRepository.DeleteByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database with retries disabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = false
|
||||
backupConfig.MaxFailedTriesCount = 0
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add failed backup
|
||||
failMessage := "backup failed"
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1) // Should still be 1 backup, no retry attempted
|
||||
|
||||
// cleanup
|
||||
for _, backup := range backups {
|
||||
err := backupRepository.DeleteByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database with retries enabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = true
|
||||
backupConfig.MaxFailedTriesCount = 3
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add failed backup
|
||||
failMessage := "backup failed"
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 2) // Should have 2 backups, retry was attempted
|
||||
|
||||
// cleanup
|
||||
for _, backup := range backups {
|
||||
err := backupRepository.DeleteByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(100 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database with retries enabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = true
|
||||
backupConfig.MaxFailedTriesCount = 3
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
failMessage := "backup failed"
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
}
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 3) // Should have 3 backups, not more than max
|
||||
|
||||
// cleanup
|
||||
for _, backup := range backups {
|
||||
err := backupRepository.DeleteByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackupContextManager struct {
|
||||
mu sync.RWMutex
|
||||
cancelFuncs map[uuid.UUID]context.CancelFunc
|
||||
cancelledBackups map[uuid.UUID]bool
|
||||
}
|
||||
|
||||
func NewBackupContextManager() *BackupContextManager {
|
||||
return &BackupContextManager{
|
||||
cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
|
||||
cancelledBackups: make(map[uuid.UUID]bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.cancelFuncs[backupID] = cancelFunc
|
||||
delete(m.cancelledBackups, backupID)
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.cancelledBackups[backupID] {
|
||||
return nil
|
||||
}
|
||||
|
||||
cancelFunc, exists := m.cancelFuncs[backupID]
|
||||
if exists {
|
||||
cancelFunc()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
}
|
||||
|
||||
m.cancelledBackups[backupID] = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) IsCancelled(backupID uuid.UUID) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.cancelledBackups[backupID]
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) UnregisterBackup(backupID uuid.UUID) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
delete(m.cancelledBackups, backupID)
|
||||
}
|
||||
365
backend/internal/features/backups/backups/backuping/backuper.go
Normal file
365
backend/internal/features/backups/backups/backuping/backuper.go
Normal file
@@ -0,0 +1,365 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
task_registry "databasus-backend/internal/features/tasks/registry"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
heartbeatTickerInterval = 15 * time.Second
|
||||
backuperHeathcheckThreshold = 5 * time.Minute
|
||||
)
|
||||
|
||||
type BackuperNode struct {
|
||||
databaseService *databases.DatabaseService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
backupRepository *backups_core.BackupRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
notificationSender backups_core.NotificationSender
|
||||
backupCancelManager *tasks_cancellation.TaskCancelManager
|
||||
tasksRegistry *task_registry.TaskNodesRegistry
|
||||
logger *slog.Logger
|
||||
createBackupUseCase backups_core.CreateBackupUsecase
|
||||
nodeID uuid.UUID
|
||||
|
||||
lastHeartbeat time.Time
|
||||
}
|
||||
|
||||
func (n *BackuperNode) Run(ctx context.Context) {
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
|
||||
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
|
||||
|
||||
backupNode := task_registry.TaskNode{
|
||||
ID: n.nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
}
|
||||
|
||||
if err := n.tasksRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
|
||||
n.logger.Error("Failed to register node in registry", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
|
||||
n.MakeBackup(backupID, isCallNotifier)
|
||||
if err := n.tasksRegistry.PublishTaskCompletion(n.nodeID.String(), backupID); err != nil {
|
||||
n.logger.Error(
|
||||
"Failed to publish backup completion",
|
||||
"error",
|
||||
err,
|
||||
"backupID",
|
||||
backupID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if err := n.tasksRegistry.SubscribeNodeForTasksAssignment(n.nodeID.String(), backupHandler); err != nil {
|
||||
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := n.tasksRegistry.UnsubscribeNodeForTasksAssignments(); err != nil {
|
||||
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(heartbeatTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
|
||||
|
||||
if err := n.tasksRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
|
||||
n.logger.Error("Failed to unregister node from registry", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
n.sendHeartbeat(&backupNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *BackuperNode) IsBackuperRunning() bool {
|
||||
return n.lastHeartbeat.After(time.Now().UTC().Add(-backuperHeathcheckThreshold))
|
||||
}
|
||||
|
||||
func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
backup, err := n.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get backup by ID", "backupId", backupID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
databaseID := backup.DatabaseID
|
||||
|
||||
database, err := n.databaseService.GetDatabaseByID(databaseID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get database by ID", "databaseId", databaseID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backupConfig, err := n.backupConfigService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if backupConfig.StorageID == nil {
|
||||
n.logger.Error("Backup config storage ID is not defined")
|
||||
return
|
||||
}
|
||||
|
||||
storage, err := n.storageService.GetStorageByID(*backupConfig.StorageID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get storage by ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now().UTC()
|
||||
|
||||
backupProgressListener := func(
|
||||
completedMBs float64,
|
||||
) {
|
||||
backup.BackupSizeMb = completedMBs
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to update backup progress", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
n.backupCancelManager.RegisterTask(backup.ID, cancel)
|
||||
defer n.backupCancelManager.UnregisterTask(backup.ID)
|
||||
|
||||
backupMetadata, err := n.createBackupUseCase.Execute(
|
||||
ctx,
|
||||
backup.ID,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
backupProgressListener,
|
||||
)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
|
||||
// Log detailed error information for debugging
|
||||
n.logger.Error("Backup execution failed",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", databaseID,
|
||||
"databaseType", database.Type,
|
||||
"storageId", storage.ID,
|
||||
"storageType", storage.Type,
|
||||
"error", err,
|
||||
"errorMessage", errMsg,
|
||||
)
|
||||
|
||||
// Check if backup was cancelled (not due to shutdown)
|
||||
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
|
||||
strings.Contains(errMsg, "context canceled") ||
|
||||
errors.Is(err, context.Canceled)
|
||||
isShutdown := strings.Contains(errMsg, "shutdown")
|
||||
|
||||
if isCancelled && !isShutdown {
|
||||
n.logger.Warn("Backup was cancelled by user or system",
|
||||
"backupId", backup.ID,
|
||||
"isCancelled", isCancelled,
|
||||
"isShutdown", isShutdown,
|
||||
)
|
||||
|
||||
backup.Status = backups_core.BackupStatusCanceled
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save cancelled backup", "error", err)
|
||||
}
|
||||
|
||||
// Delete partial backup from storage
|
||||
storage, storageErr := n.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr == nil {
|
||||
if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.ID); deleteErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to delete partial backup file",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
deleteErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.FailMessage = &errMsg
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if updateErr := n.databaseService.SetBackupError(databaseID, errMsg); updateErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup", "error", err)
|
||||
}
|
||||
|
||||
n.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&errMsg,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.Status = backups_core.BackupStatusCompleted
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Update backup with encryption metadata if provided
|
||||
if backupMetadata != nil {
|
||||
backup.EncryptionSalt = backupMetadata.EncryptionSalt
|
||||
backup.EncryptionIV = backupMetadata.EncryptionIV
|
||||
backup.Encryption = backupMetadata.Encryption
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update database last backup time
|
||||
now := time.Now().UTC()
|
||||
if updateErr := n.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if backup.Status != backups_core.BackupStatusCompleted && !isCallNotifier {
|
||||
return
|
||||
}
|
||||
|
||||
n.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupSuccess,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func (n *BackuperNode) SendBackupNotification(
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
backup *backups_core.Backup,
|
||||
notificationType backups_config.BackupNotificationType,
|
||||
errorMessage *string,
|
||||
) {
|
||||
database, err := n.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
workspace, err := n.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, notifier := range database.Notifiers {
|
||||
if !slices.Contains(
|
||||
backupConfig.SendNotificationsOn,
|
||||
notificationType,
|
||||
) {
|
||||
continue
|
||||
}
|
||||
|
||||
title := ""
|
||||
switch notificationType {
|
||||
case backups_config.NotificationBackupFailed:
|
||||
title = fmt.Sprintf(
|
||||
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
case backups_config.NotificationBackupSuccess:
|
||||
title = fmt.Sprintf(
|
||||
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
}
|
||||
|
||||
message := ""
|
||||
if errorMessage != nil {
|
||||
message = *errorMessage
|
||||
} else {
|
||||
// Format size conditionally
|
||||
var sizeStr string
|
||||
if backup.BackupSizeMb < 1024 {
|
||||
sizeStr = fmt.Sprintf("%.2f MB", backup.BackupSizeMb)
|
||||
} else {
|
||||
sizeGB := backup.BackupSizeMb / 1024
|
||||
sizeStr = fmt.Sprintf("%.2f GB", sizeGB)
|
||||
}
|
||||
|
||||
// Format duration as "0m 0s 0ms"
|
||||
totalMs := backup.BackupDurationMs
|
||||
minutes := totalMs / (1000 * 60)
|
||||
seconds := (totalMs % (1000 * 60)) / 1000
|
||||
durationStr := fmt.Sprintf("%dm %ds", minutes, seconds)
|
||||
|
||||
message = fmt.Sprintf(
|
||||
"Backup completed successfully in %s.\nCompressed backup size: %s",
|
||||
durationStr,
|
||||
sizeStr,
|
||||
)
|
||||
}
|
||||
|
||||
n.notificationSender.SendNotification(
|
||||
¬ifier,
|
||||
title,
|
||||
message,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *BackuperNode) sendHeartbeat(backupNode *task_registry.TaskNode) {
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
if err := n.tasksRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
|
||||
n.logger.Error("Failed to send heartbeat", "error", err)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -8,17 +8,15 @@ import (
|
||||
"time"
|
||||
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -26,6 +24,7 @@ import (
|
||||
)
|
||||
|
||||
func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
@@ -50,22 +49,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
|
||||
t.Run("BackupFailed_FailNotificationSent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backupService := &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateFailedBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.notificationSender = mockNotificationSender
|
||||
backuperNode.createBackupUseCase = &CreateFailedBackupUsecase{}
|
||||
|
||||
// Create a backup record directly that will be looked up by MakeBackup
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Set up expectations
|
||||
mockNotificationSender.On("SendNotification",
|
||||
@@ -78,7 +74,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
}),
|
||||
).Once()
|
||||
|
||||
backupService.MakeBackup(database.ID, true)
|
||||
backuperNode.MakeBackup(backup.ID, true)
|
||||
|
||||
// Verify all expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
@@ -86,6 +82,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
|
||||
t.Run("BackupSuccess_SuccessNotificationSent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.notificationSender = mockNotificationSender
|
||||
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
|
||||
|
||||
// Create a backup record directly that will be looked up by MakeBackup
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Set up expectations
|
||||
mockNotificationSender.On("SendNotification",
|
||||
@@ -98,24 +107,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
}),
|
||||
).Once()
|
||||
|
||||
backupService := &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateSuccessBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
}
|
||||
|
||||
backupService.MakeBackup(database.ID, true)
|
||||
backuperNode.MakeBackup(backup.ID, true)
|
||||
|
||||
// Verify all expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
@@ -123,22 +115,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
|
||||
t.Run("BackupSuccess_VerifyNotificationContent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backupService := &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateSuccessBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.notificationSender = mockNotificationSender
|
||||
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
|
||||
|
||||
// Create a backup record directly that will be looked up by MakeBackup
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// capture arguments
|
||||
var capturedNotifier *notifiers.Notifier
|
||||
@@ -155,7 +144,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
capturedMessage = args.Get(2).(string)
|
||||
}).Once()
|
||||
|
||||
backupService.MakeBackup(database.ID, true)
|
||||
backuperNode.MakeBackup(backup.ID, true)
|
||||
|
||||
// Verify expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
71
backend/internal/features/backups/backups/backuping/di.go
Normal file
71
backend/internal/features/backups/backups/backuping/di.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
task_registry "databasus-backend/internal/features/tasks/registry"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var backupRepository = &backups_core.BackupRepository{}
|
||||
|
||||
var taskCancelManager = tasks_cancellation.GetTaskCancelManager()
|
||||
|
||||
var nodesRegistry = task_registry.GetTaskNodesRegistry()
|
||||
|
||||
func getNodeID() uuid.UUID {
|
||||
nodeIDStr := config.GetEnv().NodeID
|
||||
nodeID, err := uuid.Parse(nodeIDStr)
|
||||
if err != nil {
|
||||
logger.GetLogger().Error("Failed to parse node ID from config", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
return nodeID
|
||||
}
|
||||
|
||||
var backuperNode = &BackuperNode{
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
taskCancelManager,
|
||||
nodesRegistry,
|
||||
logger.GetLogger(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
getNodeID(),
|
||||
time.Time{},
|
||||
}
|
||||
|
||||
var backupsScheduler = &BackupsScheduler{
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
taskCancelManager,
|
||||
nodesRegistry,
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode,
|
||||
}
|
||||
|
||||
func GetBackupsScheduler() *BackupsScheduler {
|
||||
return backupsScheduler
|
||||
}
|
||||
|
||||
func GetBackuperNode() *BackuperNode {
|
||||
return backuperNode
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package backuping
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type BackupToNodeRelation struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
BackupsIDs []uuid.UUID `json:"backupsIds"`
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
603
backend/internal/features/backups/backups/backuping/scheduler.go
Normal file
603
backend/internal/features/backups/backups/backuping/scheduler.go
Normal file
@@ -0,0 +1,603 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/storages"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
task_registry "databasus-backend/internal/features/tasks/registry"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/period"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
schedulerStartupDelay = 1 * time.Minute
|
||||
schedulerTickerInterval = 1 * time.Minute
|
||||
schedulerHealthcheckThreshold = 5 * time.Minute
|
||||
)
|
||||
|
||||
type BackupsScheduler struct {
|
||||
backupRepository *backups_core.BackupRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
taskCancelManager *task_cancellation.TaskCancelManager
|
||||
tasksRegistry *task_registry.TaskNodesRegistry
|
||||
|
||||
lastBackupTime time.Time
|
||||
logger *slog.Logger
|
||||
|
||||
backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
|
||||
backuperNode *BackuperNode
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) Run(ctx context.Context) {
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
|
||||
if config.GetEnv().IsManyNodesMode {
|
||||
// wait other nodes to start
|
||||
time.Sleep(schedulerStartupDelay)
|
||||
}
|
||||
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := s.tasksRegistry.SubscribeForTasksCompletions(s.onBackupCompleted); err != nil {
|
||||
s.logger.Error("Failed to subscribe to backup completions", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := s.tasksRegistry.UnsubscribeForTasksCompletions(); err != nil {
|
||||
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(schedulerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldBackups(); err != nil {
|
||||
s.logger.Error("Failed to clean old backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.checkDeadNodesAndFailBackups(); err != nil {
|
||||
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) IsSchedulerRunning() bool {
|
||||
// if last backup time is more than 5 minutes ago, return false
|
||||
return s.lastBackupTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) failBackupsInProgress() error {
|
||||
backupsInProgress, err := s.backupRepository.FindByStatus(backups_core.BackupStatusInProgress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backup := range backupsInProgress {
|
||||
if err := s.taskCancelManager.CancelTask(backup.ID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to cancel backup via task cancel manager",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
failMessage := "Backup failed due to application restart"
|
||||
backup.FailMessage = &failMessage
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
s.backuperNode.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&failMessage,
|
||||
)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool) {
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if backupConfig.StorageID == nil {
|
||||
s.logger.Error("Backup config storage ID is nil", "databaseId", databaseID)
|
||||
return
|
||||
}
|
||||
|
||||
leastBusyNodeID, err := s.calculateLeastBusyNode()
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to calculate least busy node",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: backupConfig.DatabaseID,
|
||||
StorageID: *backupConfig.StorageID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 0,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to save backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.tasksRegistry.IncrementTasksInProgress(leastBusyNodeID.String()); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to increment backups in progress",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.tasksRegistry.AssignTaskToNode(leastBusyNodeID.String(), backup.ID, isCallNotifier); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to submit backup",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
if decrementErr := s.tasksRegistry.DecrementTasksInProgress(leastBusyNodeID.String()); decrementErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress after submit failure",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"error",
|
||||
decrementErr,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if relation, exists := s.backupToNodeRelations[*leastBusyNodeID]; exists {
|
||||
relation.BackupsIDs = append(relation.BackupsIDs, backup.ID)
|
||||
s.backupToNodeRelations[*leastBusyNodeID] = relation
|
||||
} else {
|
||||
s.backupToNodeRelations[*leastBusyNodeID] = BackupToNodeRelation{
|
||||
NodeID: *leastBusyNodeID,
|
||||
BackupsIDs: []uuid.UUID{backup.ID},
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Successfully triggered scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
)
|
||||
}
|
||||
|
||||
// GetRemainedBackupTryCount returns the number of remaining backup tries for a given backup.
|
||||
// If the backup is not failed or the backup config does not allow retries, it returns 0.
|
||||
// If the backup is failed and the backup config allows retries, it returns the number of remaining tries.
|
||||
// If the backup is failed and the backup config does not allow retries, it returns 0.
|
||||
func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Backup) int {
|
||||
if lastBackup == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if lastBackup.Status != backups_core.BackupStatusFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
if !backupConfig.IsRetryIfFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
maxFailedTriesCount := backupConfig.MaxFailedTriesCount
|
||||
|
||||
lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
|
||||
lastBackup.DatabaseID,
|
||||
maxFailedTriesCount,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to find last backups by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
lastFailedBackups := make([]*backups_core.Backup, 0)
|
||||
|
||||
for _, backup := range lastBackups {
|
||||
if backup.Status == backups_core.BackupStatusFailed {
|
||||
lastFailedBackups = append(lastFailedBackups, backup)
|
||||
}
|
||||
}
|
||||
|
||||
return maxFailedTriesCount - len(lastFailedBackups)
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) cleanOldBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
backupStorePeriod := backupConfig.StorePeriod
|
||||
|
||||
if backupStorePeriod == period.PeriodForever {
|
||||
continue
|
||||
}
|
||||
|
||||
storeDuration := backupStorePeriod.ToDuration()
|
||||
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
|
||||
|
||||
oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
|
||||
backupConfig.DatabaseID,
|
||||
dateBeforeBackupsShouldBeDeleted,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find old backups for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, backup := range oldBackups {
|
||||
storage, err := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get storage by ID",
|
||||
"storageId",
|
||||
backup.StorageID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
err = storage.DeleteFile(encryptor, backup.ID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
|
||||
}
|
||||
|
||||
if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
|
||||
s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Deleted old backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) runPendingBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
if backupConfig.BackupInterval == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get last backup for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
var lastBackupTime *time.Time
|
||||
if lastBackup != nil {
|
||||
lastBackupTime = &lastBackup.CreatedAt
|
||||
}
|
||||
|
||||
remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
|
||||
|
||||
if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
|
||||
remainedBackupTryCount > 0 {
|
||||
s.logger.Info(
|
||||
"Triggering scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"intervalType",
|
||||
backupConfig.BackupInterval.Interval,
|
||||
)
|
||||
|
||||
s.StartBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
|
||||
nodes, err := s.tasksRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get available nodes: %w", err)
|
||||
}
|
||||
|
||||
if len(nodes) == 0 {
|
||||
return nil, fmt.Errorf("no nodes available")
|
||||
}
|
||||
|
||||
stats, err := s.tasksRegistry.GetNodesStats()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get backup nodes stats: %w", err)
|
||||
}
|
||||
|
||||
statsMap := make(map[uuid.UUID]int)
|
||||
for _, stat := range stats {
|
||||
statsMap[stat.ID] = stat.ActiveTasks
|
||||
}
|
||||
|
||||
var bestNode *task_registry.TaskNode
|
||||
var bestScore float64 = -1
|
||||
|
||||
for i := range nodes {
|
||||
node := &nodes[i]
|
||||
|
||||
activeBackups := statsMap[node.ID]
|
||||
|
||||
var score float64
|
||||
if node.ThroughputMBs > 0 {
|
||||
score = float64(activeBackups) / float64(node.ThroughputMBs)
|
||||
} else {
|
||||
score = float64(activeBackups) * 1000
|
||||
}
|
||||
|
||||
if bestNode == nil || score < bestScore {
|
||||
bestNode = node
|
||||
bestScore = score
|
||||
}
|
||||
}
|
||||
|
||||
if bestNode == nil {
|
||||
return nil, fmt.Errorf("no suitable nodes available")
|
||||
}
|
||||
|
||||
return &bestNode.ID, nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUID) {
|
||||
nodeID, err := uuid.Parse(nodeIDStr)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to parse node ID from completion message",
|
||||
"nodeId",
|
||||
nodeIDStr,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify this task is actually a backup (registry contains multiple task types)
|
||||
_, err = s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
// Not a backup task, ignore it
|
||||
return
|
||||
}
|
||||
|
||||
relation, exists := s.backupToNodeRelations[nodeID]
|
||||
if !exists {
|
||||
s.logger.Warn(
|
||||
"Received completion for unknown node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
newBackupIDs := make([]uuid.UUID, 0)
|
||||
found := false
|
||||
for _, id := range relation.BackupsIDs {
|
||||
if id == backupID {
|
||||
found = true
|
||||
continue
|
||||
}
|
||||
newBackupIDs = append(newBackupIDs, id)
|
||||
}
|
||||
|
||||
if !found {
|
||||
s.logger.Warn(
|
||||
"Backup not found in node's backup list",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if len(newBackupIDs) == 0 {
|
||||
delete(s.backupToNodeRelations, nodeID)
|
||||
} else {
|
||||
relation.BackupsIDs = newBackupIDs
|
||||
s.backupToNodeRelations[nodeID] = relation
|
||||
}
|
||||
|
||||
if err := s.tasksRegistry.DecrementTasksInProgress(nodeIDStr); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
|
||||
nodes, err := s.tasksRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get available nodes: %w", err)
|
||||
}
|
||||
|
||||
aliveNodeIDs := make(map[uuid.UUID]bool)
|
||||
for _, node := range nodes {
|
||||
aliveNodeIDs[node.ID] = true
|
||||
}
|
||||
|
||||
for nodeID, relation := range s.backupToNodeRelations {
|
||||
if aliveNodeIDs[nodeID] {
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Warn(
|
||||
"Node is dead, failing its backups",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupCount",
|
||||
len(relation.BackupsIDs),
|
||||
)
|
||||
|
||||
for _, backupID := range relation.BackupsIDs {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find backup for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
failMessage := "Backup failed due to node unavailability"
|
||||
backup.FailMessage = &failMessage
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to save failed backup for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.tasksRegistry.DecrementTasksInProgress(nodeID.String()); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Failed backup due to dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
)
|
||||
}
|
||||
|
||||
delete(s.backupToNodeRelations, nodeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
280
backend/internal/features/backups/backups/backuping/testing.go
Normal file
280
backend/internal/features/backups/backups/backuping/testing.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
task_registry "databasus-backend/internal/features/tasks/registry"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func CreateTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
backups_config.GetBackupConfigController(),
|
||||
)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func CreateTestBackuperNode() *BackuperNode {
|
||||
return &BackuperNode{
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
taskCancelManager,
|
||||
nodesRegistry,
|
||||
logger.GetLogger(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
uuid.New(),
|
||||
time.Time{},
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForBackupCompletion waits for a new backup to be created and completed (or failed)
|
||||
// for the given database. It checks for backups with count greater than expectedInitialCount.
|
||||
func WaitForBackupCompletion(
|
||||
t *testing.T,
|
||||
databaseID uuid.UUID,
|
||||
expectedInitialCount int,
|
||||
timeout time.Duration,
|
||||
) {
|
||||
deadline := time.Now().UTC().Add(timeout)
|
||||
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
backups, err := backupRepository.FindByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
t.Logf("WaitForBackupCompletion: error finding backups: %v", err)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
t.Logf(
|
||||
"WaitForBackupCompletion: found %d backups (expected > %d)",
|
||||
len(backups),
|
||||
expectedInitialCount,
|
||||
)
|
||||
|
||||
if len(backups) > expectedInitialCount {
|
||||
// Check if the newest backup has completed or failed
|
||||
newestBackup := backups[0]
|
||||
t.Logf("WaitForBackupCompletion: newest backup status: %s", newestBackup.Status)
|
||||
|
||||
if newestBackup.Status == backups_core.BackupStatusCompleted ||
|
||||
newestBackup.Status == backups_core.BackupStatusFailed ||
|
||||
newestBackup.Status == backups_core.BackupStatusCanceled {
|
||||
t.Logf(
|
||||
"WaitForBackupCompletion: backup finished with status %s",
|
||||
newestBackup.Status,
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
|
||||
}
|
||||
|
||||
// StartBackuperNodeForTest starts a BackuperNode in a goroutine for testing.
|
||||
// The node registers itself in the registry and subscribes to backup assignments.
|
||||
// Returns a context cancel function that should be deferred to stop the node.
|
||||
func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
backuperNode.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Poll registry for node presence instead of fixed sleep
|
||||
deadline := time.Now().UTC().Add(5 * time.Second)
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
nodes, err := nodesRegistry.GetAvailableNodes()
|
||||
if err == nil {
|
||||
for _, node := range nodes {
|
||||
if node.ID == backuperNode.nodeID {
|
||||
t.Logf("BackuperNode registered in registry: %s", backuperNode.nodeID)
|
||||
|
||||
return func() {
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("BackuperNode stopped gracefully")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Log("BackuperNode stop timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Fatalf("BackuperNode failed to register in registry within timeout")
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartSchedulerForTest starts the BackupsScheduler in a goroutine for testing.
|
||||
// The scheduler subscribes to task completions and manages backup lifecycle.
|
||||
// Returns a context cancel function that should be deferred to stop the scheduler.
|
||||
func StartSchedulerForTest(t *testing.T) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
GetBackupsScheduler().Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Give scheduler time to subscribe to completions
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
t.Log("BackupsScheduler started")
|
||||
|
||||
return func() {
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("BackupsScheduler stopped gracefully")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Log("BackupsScheduler stop timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StopBackuperNodeForTest stops the BackuperNode by canceling its context.
|
||||
// It waits for the node to unregister from the registry.
|
||||
func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNode *BackuperNode) {
|
||||
cancel()
|
||||
|
||||
// Wait for node to unregister from registry
|
||||
deadline := time.Now().UTC().Add(2 * time.Second)
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
nodes, err := nodesRegistry.GetAvailableNodes()
|
||||
if err == nil {
|
||||
found := false
|
||||
for _, node := range nodes {
|
||||
if node.ID == backuperNode.nodeID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Logf("BackuperNode unregistered from registry: %s", backuperNode.nodeID)
|
||||
return
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("BackuperNode stop completed for %s", backuperNode.nodeID)
|
||||
}
|
||||
|
||||
func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
|
||||
backupNode := task_registry.TaskNode{
|
||||
ID: nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: lastHeartbeat,
|
||||
}
|
||||
|
||||
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
|
||||
}
|
||||
|
||||
func UpdateNodeHeartbeatDirectly(
|
||||
nodeID uuid.UUID,
|
||||
throughputMBs int,
|
||||
lastHeartbeat time.Time,
|
||||
) error {
|
||||
backupNode := task_registry.TaskNode{
|
||||
ID: nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: lastHeartbeat,
|
||||
}
|
||||
|
||||
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
|
||||
}
|
||||
|
||||
func GetNodeFromRegistry(nodeID uuid.UUID) (*task_registry.TaskNode, error) {
|
||||
nodes, err := nodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
if node.ID == nodeID {
|
||||
return &node, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("node not found")
|
||||
}
|
||||
|
||||
// WaitForActiveTasksDecrease waits for the active task count to decrease below the initial count.
|
||||
// It polls the registry every 500ms until the count decreases or the timeout is reached.
|
||||
// Returns true if the count decreased, false if timeout was reached.
|
||||
func WaitForActiveTasksDecrease(
|
||||
t *testing.T,
|
||||
nodeID uuid.UUID,
|
||||
initialCount int,
|
||||
timeout time.Duration,
|
||||
) bool {
|
||||
deadline := time.Now().UTC().Add(timeout)
|
||||
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
stats, err := nodesRegistry.GetNodesStats()
|
||||
if err != nil {
|
||||
t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, stat := range stats {
|
||||
if stat.ID == nodeID {
|
||||
t.Logf(
|
||||
"WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)",
|
||||
stat.ActiveTasks,
|
||||
initialCount,
|
||||
)
|
||||
if stat.ActiveTasks < initialCount {
|
||||
t.Logf(
|
||||
"WaitForActiveTasksDecrease: active tasks decreased from %d to %d",
|
||||
initialCount,
|
||||
stat.ActiveTasks,
|
||||
)
|
||||
return true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("WaitForActiveTasksDecrease: timeout waiting for active tasks to decrease")
|
||||
return false
|
||||
}
|
||||
@@ -1,11 +1,15 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
"databasus-backend/internal/features/databases"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -18,11 +22,17 @@ type BackupController struct {
|
||||
func (c *BackupController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
router.GET("/backups", c.GetBackups)
|
||||
router.POST("/backups", c.MakeBackup)
|
||||
router.GET("/backups/:id/file", c.GetFile)
|
||||
router.POST("/backups/:id/download-token", c.GenerateDownloadToken)
|
||||
router.DELETE("/backups/:id", c.DeleteBackup)
|
||||
router.POST("/backups/:id/cancel", c.CancelBackup)
|
||||
}
|
||||
|
||||
// RegisterPublicRoutes registers routes that don't require Bearer authentication
|
||||
// (they have their own authentication mechanisms like download tokens)
|
||||
func (c *BackupController) RegisterPublicRoutes(router *gin.RouterGroup) {
|
||||
router.GET("/backups/:id/file", c.GetFile)
|
||||
}
|
||||
|
||||
// GetBackups
|
||||
// @Summary Get backups for a database
|
||||
// @Description Get paginated backups for the specified database
|
||||
@@ -159,17 +169,17 @@ func (c *BackupController) CancelBackup(ctx *gin.Context) {
|
||||
ctx.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// GetFile
|
||||
// @Summary Download a backup file
|
||||
// @Description Download the backup file for the specified backup
|
||||
// GenerateDownloadToken
|
||||
// @Summary Generate short-lived download token
|
||||
// @Description Generate a token for downloading a backup file (valid for 5 minutes)
|
||||
// @Tags backups
|
||||
// @Param id path string true "Backup ID"
|
||||
// @Success 200 {file} file
|
||||
// @Success 200 {object} backups_download.GenerateDownloadTokenResponse
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /backups/{id}/file [get]
|
||||
func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
// @Failure 409 {object} map[string]string "Download already in progress"
|
||||
// @Router /backups/{id}/download-token [post]
|
||||
func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
@@ -182,30 +192,122 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
fileReader, backup, database, err := c.backupService.GetBackupFile(user, id)
|
||||
response, err := c.backupService.GenerateDownloadToken(user, id)
|
||||
if err != nil {
|
||||
if err == backups_download.ErrDownloadAlreadyInProgress {
|
||||
ctx.JSON(
|
||||
http.StatusConflict,
|
||||
gin.H{
|
||||
"error": "Download already in progress for some of backups. Please wait until previous download completed or cancel it",
|
||||
},
|
||||
)
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetFile
|
||||
// @Summary Download a backup file
|
||||
// @Description Download the backup file for the specified backup using a download token.
|
||||
// @Description
|
||||
// @Description **Download Concurrency Control:**
|
||||
// @Description - Only one download per user is allowed at a time
|
||||
// @Description - If a download is already in progress, returns 409 Conflict
|
||||
// @Description - Downloads are tracked using cache with 5-second TTL and 3-second heartbeat
|
||||
// @Description - Browser cancellations automatically release the download lock
|
||||
// @Description - Server crashes are handled via automatic cache expiry (5 seconds)
|
||||
// @Tags backups
|
||||
// @Param id path string true "Backup ID"
|
||||
// @Param token query string true "Download token"
|
||||
// @Success 200 {file} file
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 409 {object} map[string]string "Download already in progress"
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /backups/{id}/file [get]
|
||||
func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
token := ctx.Query("token")
|
||||
if token == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "download token is required"})
|
||||
return
|
||||
}
|
||||
|
||||
backupIDParam := ctx.Param("id")
|
||||
backupID, err := uuid.Parse(backupIDParam)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
|
||||
return
|
||||
}
|
||||
|
||||
downloadToken, rateLimiter, err := c.backupService.ValidateDownloadToken(token)
|
||||
if err != nil {
|
||||
if err == backups_download.ErrDownloadAlreadyInProgress {
|
||||
ctx.JSON(
|
||||
http.StatusConflict,
|
||||
gin.H{
|
||||
"error": "download already in progress for this user. Please wait until previous download completed or cancel it",
|
||||
},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
|
||||
return
|
||||
}
|
||||
|
||||
if downloadToken.BackupID != backupID {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
|
||||
return
|
||||
}
|
||||
|
||||
fileReader, backup, database, err := c.backupService.GetBackupFileWithoutAuth(
|
||||
downloadToken.BackupID,
|
||||
)
|
||||
if err != nil {
|
||||
c.backupService.UnregisterDownload(downloadToken.UserID)
|
||||
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
rateLimitedReader := backups_download.NewRateLimitedReader(fileReader, rateLimiter)
|
||||
|
||||
heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
|
||||
defer func() {
|
||||
if err := fileReader.Close(); err != nil {
|
||||
cancelHeartbeat()
|
||||
c.backupService.UnregisterDownload(downloadToken.UserID)
|
||||
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
|
||||
if err := rateLimitedReader.Close(); err != nil {
|
||||
fmt.Printf("Error closing file reader: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go c.startDownloadHeartbeat(heartbeatCtx, downloadToken.UserID)
|
||||
|
||||
filename := c.generateBackupFilename(backup, database)
|
||||
|
||||
if backup.BackupSizeMb > 0 {
|
||||
sizeBytes := int64(backup.BackupSizeMb * 1024 * 1024)
|
||||
ctx.Header("Content-Length", fmt.Sprintf("%d", sizeBytes))
|
||||
}
|
||||
|
||||
ctx.Header("Content-Type", "application/octet-stream")
|
||||
ctx.Header(
|
||||
"Content-Disposition",
|
||||
fmt.Sprintf("attachment; filename=\"%s\"", filename),
|
||||
)
|
||||
|
||||
_, err = io.Copy(ctx.Writer, fileReader)
|
||||
_, err = io.Copy(ctx.Writer, rateLimitedReader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "failed to stream file"})
|
||||
fmt.Printf("Error streaming file: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
|
||||
}
|
||||
|
||||
type MakeBackupRequest struct {
|
||||
@@ -213,7 +315,7 @@ type MakeBackupRequest struct {
|
||||
}
|
||||
|
||||
func (c *BackupController) generateBackupFilename(
|
||||
backup *Backup,
|
||||
backup *backups_core.Backup,
|
||||
database *databases.Database,
|
||||
) string {
|
||||
// Format timestamp as YYYY-MM-DD_HH-mm-ss
|
||||
@@ -270,3 +372,17 @@ func sanitizeFilename(name string) string {
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uuid.UUID) {
|
||||
ticker := time.NewTicker(backups_download.GetDownloadHeartbeatInterval())
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.backupService.RefreshDownloadLock(userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -15,7 +16,10 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
@@ -87,7 +91,13 @@ func Test_GetBackups_PermissionsEnforced(t *testing.T) {
|
||||
testUserToken = owner.Token
|
||||
} else {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
member,
|
||||
*tt.workspaceRole,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
} else {
|
||||
@@ -179,7 +189,13 @@ func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
|
||||
testUserToken = owner.Token
|
||||
} else {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
member,
|
||||
*tt.workspaceRole,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
} else {
|
||||
@@ -309,7 +325,13 @@ func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
|
||||
testUserToken = owner.Token
|
||||
} else {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
member,
|
||||
*tt.workspaceRole,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
} else {
|
||||
@@ -378,7 +400,7 @@ func Test_DeleteBackup_AuditLogWritten(t *testing.T) {
|
||||
assert.True(t, found, "Audit log for backup deletion not found")
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_PermissionsEnforced(t *testing.T) {
|
||||
func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
@@ -387,28 +409,28 @@ func Test_DownloadBackup_PermissionsEnforced(t *testing.T) {
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace viewer can download backup",
|
||||
name: "workspace viewer can generate token",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can download backup",
|
||||
name: "workspace member can generate token",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot download backup",
|
||||
name: "non-member cannot generate token",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can download backup",
|
||||
name: "global admin can generate token",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
@@ -433,7 +455,13 @@ func Test_DownloadBackup_PermissionsEnforced(t *testing.T) {
|
||||
testUserToken = owner.Token
|
||||
} else {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
member,
|
||||
*tt.workspaceRole,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
} else {
|
||||
@@ -441,21 +469,244 @@ func Test_DownloadBackup_PermissionsEnforced(t *testing.T) {
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+testUserToken,
|
||||
nil,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
|
||||
if !tt.expectSuccess {
|
||||
if tt.expectSuccess {
|
||||
var response backups_download.GenerateDownloadTokenResponse
|
||||
err := json.Unmarshal(testResp.Body, &response)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, response.Token)
|
||||
assert.NotEmpty(t, response.Filename)
|
||||
assert.Equal(t, backup.ID, response.BackupID)
|
||||
} else {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_WithValidToken_Success(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
// Generate download token
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&tokenResponse,
|
||||
)
|
||||
|
||||
// Download with token
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), tokenResponse.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
// Verify response
|
||||
contentDisposition := testResp.Headers.Get("Content-Disposition")
|
||||
assert.Contains(t, contentDisposition, "attachment")
|
||||
assert.Contains(t, contentDisposition, tokenResponse.Filename)
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_WithoutToken_Unauthorized(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
// Try to download without token
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
|
||||
"",
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "download token is required")
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_WithInvalidToken_Unauthorized(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
// Try to download with invalid token
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), "invalid-token-xyz"),
|
||||
"",
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_WithExpiredToken_Unauthorized(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
// Get user for token generation
|
||||
userService := users_services.GetUserService()
|
||||
user, err := userService.GetUserFromToken(owner.Token)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create an expired token directly in the database
|
||||
expiredToken := createExpiredDownloadToken(backup.ID, user.ID)
|
||||
|
||||
// Try to download with expired token
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), expiredToken),
|
||||
"",
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
|
||||
|
||||
// Verify audit log was NOT created for failed download
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
auditLogService := audit_logs.GetAuditLogService()
|
||||
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
|
||||
workspace.ID,
|
||||
&audit_logs.GetAuditLogsRequest{
|
||||
Limit: 100,
|
||||
Offset: 0,
|
||||
},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
found := false
|
||||
for _, log := range auditLogs.AuditLogs {
|
||||
if strings.Contains(log.Message, "Backup file downloaded") &&
|
||||
strings.Contains(log.Message, database.Name) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.False(t, found, "Audit log should NOT be created for failed download with expired token")
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_TokenUsedOnce_CannotReuseToken(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
// Generate download token
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&tokenResponse,
|
||||
)
|
||||
|
||||
// Download with token (first time - should succeed)
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), tokenResponse.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
// Try to download again with same token (should fail)
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), tokenResponse.Token),
|
||||
"",
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_WithDifferentBackupToken_Unauthorized(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database1 := createTestDatabase("Database 1", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config1, err := configService.GetBackupConfigByDbId(database1.ID)
|
||||
assert.NoError(t, err)
|
||||
config1.IsBackupsEnabled = true
|
||||
config1.StorageID = &storage.ID
|
||||
config1.Storage = storage
|
||||
_, err = configService.SaveBackupConfig(config1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup1 := createTestBackup(database1, owner)
|
||||
|
||||
database2 := createTestDatabase("Database 2", workspace.ID, owner.Token, router)
|
||||
config2, err := configService.GetBackupConfigByDbId(database2.ID)
|
||||
assert.NoError(t, err)
|
||||
config2.IsBackupsEnabled = true
|
||||
config2.StorageID = &storage.ID
|
||||
config2.Storage = storage
|
||||
_, err = configService.SaveBackupConfig(config2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup2 := createTestBackup(database2, owner)
|
||||
|
||||
// Generate token for backup1
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup1.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&tokenResponse,
|
||||
)
|
||||
|
||||
// Try to use backup1's token to download backup2 (should fail)
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup2.ID.String(), tokenResponse.Token),
|
||||
"",
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
@@ -463,11 +714,24 @@ func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
// Generate download token
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&tokenResponse,
|
||||
)
|
||||
|
||||
// Download with token
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), tokenResponse.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
@@ -542,11 +806,28 @@ func Test_DownloadBackup_ProperFilenameForPostgreSQL(t *testing.T) {
|
||||
|
||||
backup := createTestBackup(database, owner)
|
||||
|
||||
// Generate download token
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&tokenResponse,
|
||||
)
|
||||
|
||||
// Download with token
|
||||
resp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups/%s/file?token=%s",
|
||||
backup.ID.String(),
|
||||
tokenResponse.Token,
|
||||
),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
@@ -617,22 +898,22 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup := &Backup{
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: BackupStatusInProgress,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 0,
|
||||
BackupDurationMs: 0,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &BackupRepository{}
|
||||
repo := &backups_core.BackupRepository{}
|
||||
err = repo.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Register a cancellable context for the backup
|
||||
GetBackupService().backupContextManager.RegisterBackup(backup.ID, func() {})
|
||||
GetBackupService().taskCancelManager.RegisterTask(backup.ID, func() {})
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
@@ -669,6 +950,189 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
|
||||
assert.True(t, foundCancelLog, "Cancel audit log should be created")
|
||||
}
|
||||
|
||||
func Test_ConcurrentDownloadPrevention(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var token1Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token1Response,
|
||||
)
|
||||
|
||||
var token2Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token2Response,
|
||||
)
|
||||
|
||||
downloadInProgress := make(chan bool, 1)
|
||||
downloadComplete := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups/%s/file?token=%s",
|
||||
backup.ID.String(),
|
||||
token1Response.Token,
|
||||
),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
downloadComplete <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
service := GetBackupService()
|
||||
if !service.IsDownloadInProgress(owner.UserID) {
|
||||
t.Log("Warning: First download completed before we could test concurrency")
|
||||
<-downloadComplete
|
||||
return
|
||||
}
|
||||
|
||||
downloadInProgress <- true
|
||||
|
||||
resp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token2Response.Token),
|
||||
"",
|
||||
http.StatusConflict,
|
||||
)
|
||||
|
||||
var errorResponse map[string]string
|
||||
err := json.Unmarshal(resp.Body, &errorResponse)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, errorResponse["error"], "download already in progress")
|
||||
|
||||
<-downloadComplete
|
||||
<-downloadInProgress
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
var token3Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token3Response,
|
||||
)
|
||||
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token3Response.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
t.Log("Database:", database.Name)
|
||||
t.Log(
|
||||
"Successfully prevented concurrent downloads and allowed subsequent downloads after completion",
|
||||
)
|
||||
}
|
||||
|
||||
func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var token1Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token1Response,
|
||||
)
|
||||
|
||||
downloadComplete := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups/%s/file?token=%s",
|
||||
backup.ID.String(),
|
||||
token1Response.Token,
|
||||
),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
downloadComplete <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
service := GetBackupService()
|
||||
if !service.IsDownloadInProgress(owner.UserID) {
|
||||
t.Log("Warning: First download completed before we could test token generation blocking")
|
||||
<-downloadComplete
|
||||
return
|
||||
}
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusConflict,
|
||||
)
|
||||
|
||||
var errorResponse map[string]string
|
||||
err := json.Unmarshal(resp.Body, &errorResponse)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, errorResponse["error"], "download already in progress")
|
||||
|
||||
<-downloadComplete
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
var token2Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token2Response,
|
||||
)
|
||||
|
||||
assert.NotEmpty(t, token2Response.Token)
|
||||
assert.NotEqual(t, token1Response.Token, token2Response.Token)
|
||||
|
||||
t.Log("Database:", database.Name)
|
||||
t.Log(
|
||||
"Successfully blocked token generation during download and allowed generation after completion",
|
||||
)
|
||||
}
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
return CreateTestRouter()
|
||||
}
|
||||
@@ -679,7 +1143,13 @@ func createTestDatabase(
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
env := config.GetEnv()
|
||||
port, err := strconv.Atoi(env.TestPostgres16Port)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
request := databases.Database{
|
||||
Name: name,
|
||||
WorkspaceID: &workspaceID,
|
||||
@@ -687,9 +1157,9 @@ func createTestDatabase(
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
@@ -752,7 +1222,7 @@ func createTestDatabaseWithBackups(
|
||||
workspace *workspaces_models.Workspace,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
router *gin.Engine,
|
||||
) (*databases.Database, *Backup) {
|
||||
) (*databases.Database, *backups_core.Backup) {
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
@@ -778,7 +1248,7 @@ func createTestDatabaseWithBackups(
|
||||
func createTestBackup(
|
||||
database *databases.Database,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
) *Backup {
|
||||
) *backups_core.Backup {
|
||||
userService := users_services.GetUserService()
|
||||
user, err := userService.GetUserFromToken(owner.Token)
|
||||
if err != nil {
|
||||
@@ -790,17 +1260,17 @@ func createTestBackup(
|
||||
panic("No storage found for workspace")
|
||||
}
|
||||
|
||||
backup := &Backup{
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storages[0].ID,
|
||||
Status: BackupStatusCompleted,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &BackupRepository{}
|
||||
repo := &backups_core.BackupRepository{}
|
||||
if err := repo.Save(backup); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -809,9 +1279,302 @@ func createTestBackup(
|
||||
dummyContent := []byte("dummy backup content for testing")
|
||||
reader := strings.NewReader(string(dummyContent))
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
if err := storages[0].SaveFile(context.Background(), encryption.GetFieldEncryptor(), logger, backup.ID, reader); err != nil {
|
||||
if err := storages[0].SaveFile(
|
||||
context.Background(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
logger,
|
||||
backup.ID,
|
||||
reader,
|
||||
); err != nil {
|
||||
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
|
||||
}
|
||||
|
||||
return backup
|
||||
}
|
||||
|
||||
func createExpiredDownloadToken(backupID, userID uuid.UUID) string {
|
||||
tokenService := GetBackupService().downloadTokenService
|
||||
token, err := tokenService.Generate(backupID, userID)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to generate download token: %v", err))
|
||||
}
|
||||
|
||||
// Manually update the token to be expired
|
||||
repo := &backups_download.DownloadTokenRepository{}
|
||||
downloadToken, err := repo.FindByToken(token)
|
||||
if err != nil || downloadToken == nil {
|
||||
panic(fmt.Sprintf("Failed to find generated token: %v", err))
|
||||
}
|
||||
|
||||
// Set expiration to 10 minutes ago
|
||||
downloadToken.ExpiresAt = time.Now().UTC().Add(-10 * time.Minute)
|
||||
if err := repo.Update(downloadToken); err != nil {
|
||||
panic(fmt.Sprintf("Failed to update token expiration: %v", err))
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
func Test_BandwidthThrottling_SingleDownload_Uses75Percent(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
bandwidthManager := backups_download.GetBandwidthManager()
|
||||
initialCount := bandwidthManager.GetActiveDownloadCount()
|
||||
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&tokenResponse,
|
||||
)
|
||||
|
||||
downloadStarted := make(chan bool, 1)
|
||||
downloadComplete := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups/%s/file?token=%s",
|
||||
backup.ID.String(),
|
||||
tokenResponse.Token,
|
||||
),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
downloadComplete <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
activeCount := bandwidthManager.GetActiveDownloadCount()
|
||||
if activeCount > initialCount {
|
||||
downloadStarted <- true
|
||||
assert.Equal(t, initialCount+1, activeCount, "Should have one active download")
|
||||
}
|
||||
|
||||
<-downloadComplete
|
||||
if len(downloadStarted) > 0 {
|
||||
<-downloadStarted
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
finalCount := bandwidthManager.GetActiveDownloadCount()
|
||||
assert.Equal(t, initialCount, finalCount, "Download should be unregistered after completion")
|
||||
}
|
||||
|
||||
func Test_BandwidthThrottling_MultipleDownloads_ShareBandwidth(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner3 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner1, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
owner2,
|
||||
users_enums.WorkspaceRoleMember,
|
||||
owner1.Token,
|
||||
router,
|
||||
)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
owner3,
|
||||
users_enums.WorkspaceRoleMember,
|
||||
owner1.Token,
|
||||
router,
|
||||
)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner1.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.IsBackupsEnabled = true
|
||||
config.StorageID = &storage.ID
|
||||
config.Storage = storage
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup1 := createTestBackup(database, owner1)
|
||||
backup2 := createTestBackup(database, owner2)
|
||||
backup3 := createTestBackup(database, owner3)
|
||||
|
||||
var token1, token2, token3 backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup1.ID.String()),
|
||||
"Bearer "+owner1.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token1,
|
||||
)
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup2.ID.String()),
|
||||
"Bearer "+owner2.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token2,
|
||||
)
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup3.ID.String()),
|
||||
"Bearer "+owner3.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token3,
|
||||
)
|
||||
|
||||
bandwidthManager := backups_download.GetBandwidthManager()
|
||||
initialCount := bandwidthManager.GetActiveDownloadCount()
|
||||
|
||||
complete1 := make(chan bool, 1)
|
||||
complete2 := make(chan bool, 1)
|
||||
complete3 := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup1.ID.String(), token1.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete1 <- true
|
||||
}()
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup2.ID.String(), token2.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete2 <- true
|
||||
}()
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup3.ID.String(), token3.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete3 <- true
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
<-complete1
|
||||
<-complete2
|
||||
<-complete3
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
finalCount := bandwidthManager.GetActiveDownloadCount()
|
||||
assert.Equal(t, initialCount, finalCount, "All downloads should be unregistered")
|
||||
}
|
||||
|
||||
func Test_BandwidthThrottling_DynamicAdjustment(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner1, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
owner2,
|
||||
users_enums.WorkspaceRoleMember,
|
||||
owner1.Token,
|
||||
router,
|
||||
)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner1.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.IsBackupsEnabled = true
|
||||
config.StorageID = &storage.ID
|
||||
config.Storage = storage
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup1 := createTestBackup(database, owner1)
|
||||
backup2 := createTestBackup(database, owner2)
|
||||
|
||||
var token1, token2 backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup1.ID.String()),
|
||||
"Bearer "+owner1.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token1,
|
||||
)
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup2.ID.String()),
|
||||
"Bearer "+owner2.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token2,
|
||||
)
|
||||
|
||||
bandwidthManager := backups_download.GetBandwidthManager()
|
||||
initialCount := bandwidthManager.GetActiveDownloadCount()
|
||||
|
||||
complete1 := make(chan bool, 1)
|
||||
complete2 := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup1.ID.String(), token1.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete1 <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup2.ID.String(), token2.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete2 <- true
|
||||
}()
|
||||
|
||||
<-complete1
|
||||
<-complete2
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
finalCount := bandwidthManager.GetActiveDownloadCount()
|
||||
assert.Equal(t, initialCount, finalCount, "All downloads completed and unregistered")
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
type BackupStatus string
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
@@ -1,61 +1,47 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var backupRepository = &BackupRepository{}
|
||||
var backupRepository = &backups_core.BackupRepository{}
|
||||
|
||||
var backupContextManager = NewBackupContextManager()
|
||||
var taskCancelManager = task_cancellation.GetTaskCancelManager()
|
||||
|
||||
var backupService = &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
notifiers.GetNotifierService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
backupContextManager,
|
||||
}
|
||||
|
||||
var backupBackgroundService = &BackupBackgroundService{
|
||||
backupService,
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
backupRepository: backupRepository,
|
||||
notifierService: notifiers.GetNotifierService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
secretKeyService: encryption_secrets.GetSecretKeyService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
createBackupUseCase: usecases.GetCreateBackupUsecase(),
|
||||
logger: logger.GetLogger(),
|
||||
backupRemoveListeners: []backups_core.BackupRemoveListener{},
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
auditLogService: audit_logs.GetAuditLogService(),
|
||||
taskCancelManager: taskCancelManager,
|
||||
downloadTokenService: backups_download.GetDownloadTokenService(),
|
||||
backupSchedulerService: backuping.GetBackupsScheduler(),
|
||||
}
|
||||
|
||||
var backupController = &BackupController{
|
||||
backupService,
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
backups_config.
|
||||
GetBackupConfigService().
|
||||
SetDatabaseStorageChangeListener(backupService)
|
||||
|
||||
databases.GetDatabaseService().AddDbRemoveListener(backupService)
|
||||
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
|
||||
backupService: backupService,
|
||||
}
|
||||
|
||||
func GetBackupService() *BackupService {
|
||||
@@ -66,6 +52,11 @@ func GetBackupController() *BackupController {
|
||||
return backupController
|
||||
}
|
||||
|
||||
func GetBackupBackgroundService() *BackupBackgroundService {
|
||||
return backupBackgroundService
|
||||
func SetupDependencies() {
|
||||
backups_config.
|
||||
GetBackupConfigService().
|
||||
SetDatabaseStorageChangeListener(backupService)
|
||||
|
||||
databases.GetDatabaseService().AddDbRemoveListener(backupService)
|
||||
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
|
||||
}
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DownloadTokenBackgroundService struct {
|
||||
downloadTokenService *DownloadTokenService
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
|
||||
s.logger.Info("Starting download token cleanup background service")
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
|
||||
s.logger.Error("Failed to clean expired download tokens", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BandwidthManager struct {
|
||||
mu sync.RWMutex
|
||||
activeDownloads map[uuid.UUID]*activeDownload
|
||||
maxTotalBytesPerSecond int64
|
||||
bytesPerSecondPerDownload int64
|
||||
}
|
||||
|
||||
type activeDownload struct {
|
||||
userID uuid.UUID
|
||||
rateLimiter *RateLimiter
|
||||
}
|
||||
|
||||
func NewBandwidthManager(throughputMBs int) *BandwidthManager {
|
||||
// Use 75% of total throughput
|
||||
maxBytes := int64(throughputMBs) * 1024 * 1024 * 75 / 100
|
||||
|
||||
return &BandwidthManager{
|
||||
activeDownloads: make(map[uuid.UUID]*activeDownload),
|
||||
maxTotalBytesPerSecond: maxBytes,
|
||||
bytesPerSecondPerDownload: maxBytes,
|
||||
}
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) RegisterDownload(userID uuid.UUID) (*RateLimiter, error) {
|
||||
bm.mu.Lock()
|
||||
defer bm.mu.Unlock()
|
||||
|
||||
if _, exists := bm.activeDownloads[userID]; exists {
|
||||
return nil, fmt.Errorf("download already registered for user %s", userID)
|
||||
}
|
||||
|
||||
rateLimiter := NewRateLimiter(bm.bytesPerSecondPerDownload)
|
||||
|
||||
bm.activeDownloads[userID] = &activeDownload{
|
||||
userID: userID,
|
||||
rateLimiter: rateLimiter,
|
||||
}
|
||||
|
||||
bm.recalculateRates()
|
||||
|
||||
return rateLimiter, nil
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) UnregisterDownload(userID uuid.UUID) {
|
||||
bm.mu.Lock()
|
||||
defer bm.mu.Unlock()
|
||||
|
||||
delete(bm.activeDownloads, userID)
|
||||
bm.recalculateRates()
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) GetActiveDownloadCount() int {
|
||||
bm.mu.RLock()
|
||||
defer bm.mu.RUnlock()
|
||||
return len(bm.activeDownloads)
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) recalculateRates() {
|
||||
activeCount := len(bm.activeDownloads)
|
||||
|
||||
if activeCount == 0 {
|
||||
bm.bytesPerSecondPerDownload = bm.maxTotalBytesPerSecond
|
||||
return
|
||||
}
|
||||
|
||||
newRate := bm.maxTotalBytesPerSecond / int64(activeCount)
|
||||
bm.bytesPerSecondPerDownload = newRate
|
||||
|
||||
for _, download := range bm.activeDownloads {
|
||||
download.rateLimiter.UpdateRate(newRate)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_BandwidthManager_RegisterSingleDownload(t *testing.T) {
|
||||
throughputMBs := 100
|
||||
manager := NewBandwidthManager(throughputMBs)
|
||||
|
||||
expectedBytesPerSec := int64(100 * 1024 * 1024 * 75 / 100)
|
||||
assert.Equal(t, expectedBytesPerSec, manager.maxTotalBytesPerSecond)
|
||||
assert.Equal(t, expectedBytesPerSec, manager.bytesPerSecondPerDownload)
|
||||
|
||||
userID := uuid.New()
|
||||
rateLimiter, err := manager.RegisterDownload(userID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, rateLimiter)
|
||||
|
||||
assert.Equal(t, 1, manager.GetActiveDownloadCount())
|
||||
assert.Equal(t, expectedBytesPerSec, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedBytesPerSec, rateLimiter.bytesPerSecond)
|
||||
}
|
||||
|
||||
func Test_BandwidthManager_RegisterMultipleDownloads_BandwidthShared(t *testing.T) {
|
||||
throughputMBs := 100
|
||||
manager := NewBandwidthManager(throughputMBs)
|
||||
|
||||
maxBytes := int64(100 * 1024 * 1024 * 75 / 100)
|
||||
|
||||
user1 := uuid.New()
|
||||
rateLimiter1, err := manager.RegisterDownload(user1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, maxBytes, rateLimiter1.bytesPerSecond)
|
||||
|
||||
user2 := uuid.New()
|
||||
rateLimiter2, err := manager.RegisterDownload(user2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedPerDownload := maxBytes / 2
|
||||
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
|
||||
|
||||
user3 := uuid.New()
|
||||
rateLimiter3, err := manager.RegisterDownload(user3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedPerDownload = maxBytes / 3
|
||||
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter3.bytesPerSecond)
|
||||
assert.Equal(t, 3, manager.GetActiveDownloadCount())
|
||||
}
|
||||
|
||||
func Test_BandwidthManager_UnregisterDownload_BandwidthRebalanced(t *testing.T) {
|
||||
throughputMBs := 100
|
||||
manager := NewBandwidthManager(throughputMBs)
|
||||
|
||||
maxBytes := int64(100 * 1024 * 1024 * 75 / 100)
|
||||
|
||||
user1 := uuid.New()
|
||||
rateLimiter1, _ := manager.RegisterDownload(user1)
|
||||
|
||||
user2 := uuid.New()
|
||||
_, _ = manager.RegisterDownload(user2)
|
||||
|
||||
user3 := uuid.New()
|
||||
rateLimiter3, _ := manager.RegisterDownload(user3)
|
||||
|
||||
assert.Equal(t, 3, manager.GetActiveDownloadCount())
|
||||
expectedPerDownload := maxBytes / 3
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
|
||||
manager.UnregisterDownload(user2)
|
||||
|
||||
assert.Equal(t, 2, manager.GetActiveDownloadCount())
|
||||
expectedPerDownload = maxBytes / 2
|
||||
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter3.bytesPerSecond)
|
||||
|
||||
manager.UnregisterDownload(user1)
|
||||
|
||||
assert.Equal(t, 1, manager.GetActiveDownloadCount())
|
||||
assert.Equal(t, maxBytes, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, maxBytes, rateLimiter3.bytesPerSecond)
|
||||
|
||||
manager.UnregisterDownload(user3)
|
||||
assert.Equal(t, 0, manager.GetActiveDownloadCount())
|
||||
assert.Equal(t, maxBytes, manager.bytesPerSecondPerDownload)
|
||||
}
|
||||
|
||||
func Test_BandwidthManager_RegisterDuplicateUser_ReturnsError(t *testing.T) {
|
||||
manager := NewBandwidthManager(100)
|
||||
|
||||
userID := uuid.New()
|
||||
_, err := manager.RegisterDownload(userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = manager.RegisterDownload(userID)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "download already registered")
|
||||
}
|
||||
|
||||
func Test_RateLimiter_TokenBucketBasic(t *testing.T) {
|
||||
bytesPerSec := int64(1024 * 1024)
|
||||
limiter := NewRateLimiter(bytesPerSec)
|
||||
|
||||
assert.Equal(t, bytesPerSec, limiter.bytesPerSecond)
|
||||
assert.Equal(t, bytesPerSec*2, limiter.bucketSize)
|
||||
|
||||
start := time.Now()
|
||||
limiter.Wait(512 * 1024)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.Less(t, elapsed, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RateLimiter_UpdateRate(t *testing.T) {
|
||||
limiter := NewRateLimiter(1024 * 1024)
|
||||
|
||||
assert.Equal(t, int64(1024*1024), limiter.bytesPerSecond)
|
||||
|
||||
newRate := int64(2 * 1024 * 1024)
|
||||
limiter.UpdateRate(newRate)
|
||||
|
||||
assert.Equal(t, newRate, limiter.bytesPerSecond)
|
||||
assert.Equal(t, newRate*2, limiter.bucketSize)
|
||||
}
|
||||
|
||||
func Test_RateLimiter_ThrottlesCorrectly(t *testing.T) {
|
||||
bytesPerSec := int64(1024 * 1024)
|
||||
limiter := NewRateLimiter(bytesPerSec)
|
||||
|
||||
limiter.availableTokens = 0
|
||||
|
||||
start := time.Now()
|
||||
limiter.Wait(bytesPerSec / 2)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.GreaterOrEqual(t, elapsed, 400*time.Millisecond)
|
||||
assert.LessOrEqual(t, elapsed, 700*time.Millisecond)
|
||||
}
|
||||
48
backend/internal/features/backups/backups/download/di.go
Normal file
48
backend/internal/features/backups/backups/download/di.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var downloadTokenRepository = &DownloadTokenRepository{}
|
||||
|
||||
var downloadTracker = NewDownloadTracker(cache_utils.GetValkeyClient())
|
||||
|
||||
var bandwidthManager *BandwidthManager
|
||||
var downloadTokenService *DownloadTokenService
|
||||
var downloadTokenBackgroundService *DownloadTokenBackgroundService
|
||||
|
||||
func init() {
|
||||
env := config.GetEnv()
|
||||
throughputMBs := env.NodeNetworkThroughputMBs
|
||||
if throughputMBs == 0 {
|
||||
throughputMBs = 125
|
||||
}
|
||||
bandwidthManager = NewBandwidthManager(throughputMBs)
|
||||
|
||||
downloadTokenService = &DownloadTokenService{
|
||||
downloadTokenRepository,
|
||||
logger.GetLogger(),
|
||||
downloadTracker,
|
||||
bandwidthManager,
|
||||
}
|
||||
|
||||
downloadTokenBackgroundService = &DownloadTokenBackgroundService{
|
||||
downloadTokenService,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
}
|
||||
|
||||
func GetDownloadTokenService() *DownloadTokenService {
|
||||
return downloadTokenService
|
||||
}
|
||||
|
||||
func GetDownloadTokenBackgroundService() *DownloadTokenBackgroundService {
|
||||
return downloadTokenBackgroundService
|
||||
}
|
||||
|
||||
func GetBandwidthManager() *BandwidthManager {
|
||||
return bandwidthManager
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
package backups_download
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type GenerateDownloadTokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
Filename string `json:"filename"`
|
||||
BackupID uuid.UUID `json:"backupId"`
|
||||
}
|
||||
21
backend/internal/features/backups/backups/download/model.go
Normal file
21
backend/internal/features/backups/backups/download/model.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type DownloadToken struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey"`
|
||||
Token string `json:"token" gorm:"column:token;uniqueIndex;not null"`
|
||||
BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;not null"`
|
||||
UserID uuid.UUID `json:"userId" gorm:"column:user_id;not null"`
|
||||
ExpiresAt time.Time `json:"expiresAt" gorm:"column:expires_at;not null"`
|
||||
Used bool `json:"used" gorm:"column:used;not null;default:false"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;not null"`
|
||||
}
|
||||
|
||||
func (DownloadToken) TableName() string {
|
||||
return "download_tokens"
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
bytesPerSecond int64
|
||||
bucketSize int64
|
||||
availableTokens float64
|
||||
lastRefill time.Time
|
||||
}
|
||||
|
||||
func NewRateLimiter(bytesPerSecond int64) *RateLimiter {
|
||||
if bytesPerSecond <= 0 {
|
||||
bytesPerSecond = 1024 * 1024 * 100
|
||||
}
|
||||
|
||||
return &RateLimiter{
|
||||
bytesPerSecond: bytesPerSecond,
|
||||
bucketSize: bytesPerSecond * 2,
|
||||
availableTokens: float64(bytesPerSecond * 2),
|
||||
lastRefill: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) UpdateRate(bytesPerSecond int64) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
if bytesPerSecond <= 0 {
|
||||
bytesPerSecond = 1024 * 1024 * 100
|
||||
}
|
||||
|
||||
rl.bytesPerSecond = bytesPerSecond
|
||||
rl.bucketSize = bytesPerSecond * 2
|
||||
|
||||
if rl.availableTokens > float64(rl.bucketSize) {
|
||||
rl.availableTokens = float64(rl.bucketSize)
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) Wait(bytes int64) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
for {
|
||||
now := time.Now().UTC()
|
||||
elapsed := now.Sub(rl.lastRefill).Seconds()
|
||||
|
||||
tokensToAdd := elapsed * float64(rl.bytesPerSecond)
|
||||
rl.availableTokens += tokensToAdd
|
||||
if rl.availableTokens > float64(rl.bucketSize) {
|
||||
rl.availableTokens = float64(rl.bucketSize)
|
||||
}
|
||||
rl.lastRefill = now
|
||||
|
||||
if rl.availableTokens >= float64(bytes) {
|
||||
rl.availableTokens -= float64(bytes)
|
||||
return
|
||||
}
|
||||
|
||||
tokensNeeded := float64(bytes) - rl.availableTokens
|
||||
waitTime := time.Duration(tokensNeeded/float64(rl.bytesPerSecond)*1000) * time.Millisecond
|
||||
|
||||
if waitTime < time.Millisecond {
|
||||
waitTime = time.Millisecond
|
||||
}
|
||||
|
||||
rl.mu.Unlock()
|
||||
time.Sleep(waitTime)
|
||||
rl.mu.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
type RateLimitedReader struct {
|
||||
reader io.ReadCloser
|
||||
rateLimiter *RateLimiter
|
||||
}
|
||||
|
||||
func NewRateLimitedReader(reader io.ReadCloser, limiter *RateLimiter) *RateLimitedReader {
|
||||
return &RateLimitedReader{
|
||||
reader: reader,
|
||||
rateLimiter: limiter,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RateLimitedReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.reader.Read(p)
|
||||
if n > 0 {
|
||||
r.rateLimiter.Wait(int64(n))
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *RateLimitedReader) Close() error {
|
||||
return r.reader.Close()
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"databasus-backend/internal/storage"
|
||||
"encoding/base64"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type DownloadTokenRepository struct{}
|
||||
|
||||
func (r *DownloadTokenRepository) Create(token *DownloadToken) error {
|
||||
if token.ID == uuid.Nil {
|
||||
token.ID = uuid.New()
|
||||
}
|
||||
if token.CreatedAt.IsZero() {
|
||||
token.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
return storage.GetDb().Create(token).Error
|
||||
}
|
||||
|
||||
func (r *DownloadTokenRepository) FindByToken(token string) (*DownloadToken, error) {
|
||||
var downloadToken DownloadToken
|
||||
|
||||
err := storage.GetDb().
|
||||
Where("token = ?", token).
|
||||
First(&downloadToken).Error
|
||||
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &downloadToken, nil
|
||||
}
|
||||
|
||||
func (r *DownloadTokenRepository) Update(token *DownloadToken) error {
|
||||
return storage.GetDb().Save(token).Error
|
||||
}
|
||||
|
||||
func (r *DownloadTokenRepository) DeleteExpired(before time.Time) error {
|
||||
return storage.GetDb().
|
||||
Where("expires_at < ?", before).
|
||||
Delete(&DownloadToken{}).Error
|
||||
}
|
||||
|
||||
func GenerateSecureToken() string {
|
||||
b := make([]byte, 32)
|
||||
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic("failed to generate secure random token: " + err.Error())
|
||||
}
|
||||
|
||||
return base64.URLEncoding.EncodeToString(b)
|
||||
}
|
||||
105
backend/internal/features/backups/backups/download/service.go
Normal file
105
backend/internal/features/backups/backups/download/service.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type DownloadTokenService struct {
|
||||
repository *DownloadTokenRepository
|
||||
logger *slog.Logger
|
||||
downloadTracker *DownloadTracker
|
||||
bandwidthManager *BandwidthManager
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, error) {
|
||||
if s.downloadTracker.IsDownloadInProgress(userID) {
|
||||
return "", ErrDownloadAlreadyInProgress
|
||||
}
|
||||
|
||||
token := GenerateSecureToken()
|
||||
|
||||
downloadToken := &DownloadToken{
|
||||
Token: token,
|
||||
BackupID: backupID,
|
||||
UserID: userID,
|
||||
ExpiresAt: time.Now().UTC().Add(5 * time.Minute),
|
||||
Used: false,
|
||||
}
|
||||
|
||||
if err := s.repository.Create(downloadToken); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s.logger.Info("Generated download token", "backupId", backupID, "userId", userID)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) ValidateAndConsume(
|
||||
token string,
|
||||
) (*DownloadToken, *RateLimiter, error) {
|
||||
dt, err := s.repository.FindByToken(token)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if dt == nil {
|
||||
return nil, nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
if dt.Used {
|
||||
return nil, nil, errors.New("token already used")
|
||||
}
|
||||
|
||||
if time.Now().UTC().After(dt.ExpiresAt) {
|
||||
return nil, nil, errors.New("token expired")
|
||||
}
|
||||
|
||||
if err := s.downloadTracker.AcquireDownloadLock(dt.UserID); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
rateLimiter, err := s.bandwidthManager.RegisterDownload(dt.UserID)
|
||||
if err != nil {
|
||||
s.downloadTracker.ReleaseDownloadLock(dt.UserID)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
dt.Used = true
|
||||
if err := s.repository.Update(dt); err != nil {
|
||||
s.logger.Error("Failed to mark token as used", "error", err)
|
||||
}
|
||||
|
||||
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID, "userId", dt.UserID)
|
||||
return dt, rateLimiter, nil
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) RefreshDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTracker.RefreshDownloadLock(userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) ReleaseDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTracker.ReleaseDownloadLock(userID)
|
||||
s.logger.Info("Released download lock", "userId", userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) IsDownloadInProgress(userID uuid.UUID) bool {
|
||||
return s.downloadTracker.IsDownloadInProgress(userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) UnregisterDownload(userID uuid.UUID) {
|
||||
s.bandwidthManager.UnregisterDownload(userID)
|
||||
s.logger.Info("Unregistered from bandwidth manager", "userId", userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) CleanExpiredTokens() error {
|
||||
now := time.Now().UTC()
|
||||
if err := s.repository.DeleteExpired(now); err != nil {
|
||||
return err
|
||||
}
|
||||
s.logger.Debug("Cleaned expired download tokens")
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/valkey-io/valkey-go"
|
||||
)
|
||||
|
||||
const (
|
||||
downloadLockPrefix = "backup_download_lock:"
|
||||
downloadLockTTL = 5 * time.Second
|
||||
downloadLockValue = "1"
|
||||
downloadHeartbeatDelay = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDownloadAlreadyInProgress = errors.New("download already in progress for this user")
|
||||
)
|
||||
|
||||
type DownloadTracker struct {
|
||||
cache *cache_utils.CacheUtil[string]
|
||||
}
|
||||
|
||||
func NewDownloadTracker(client valkey.Client) *DownloadTracker {
|
||||
return &DownloadTracker{
|
||||
cache: cache_utils.NewCacheUtil[string](client, downloadLockPrefix),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) AcquireDownloadLock(userID uuid.UUID) error {
|
||||
key := userID.String()
|
||||
|
||||
existingLock := t.cache.Get(key)
|
||||
if existingLock != nil {
|
||||
return ErrDownloadAlreadyInProgress
|
||||
}
|
||||
|
||||
value := downloadLockValue
|
||||
t.cache.Set(key, &value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) RefreshDownloadLock(userID uuid.UUID) {
|
||||
key := userID.String()
|
||||
value := downloadLockValue
|
||||
t.cache.Set(key, &value)
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) ReleaseDownloadLock(userID uuid.UUID) {
|
||||
key := userID.String()
|
||||
t.cache.Invalidate(key)
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) IsDownloadInProgress(userID uuid.UUID) bool {
|
||||
key := userID.String()
|
||||
existingLock := t.cache.Get(key)
|
||||
return existingLock != nil
|
||||
}
|
||||
|
||||
func GetDownloadHeartbeatInterval() time.Duration {
|
||||
return downloadHeartbeatDelay
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
"io"
|
||||
)
|
||||
@@ -12,17 +13,17 @@ type GetBackupsRequest struct {
|
||||
}
|
||||
|
||||
type GetBackupsResponse struct {
|
||||
Backups []*Backup `json:"backups"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
Backups []*backups_core.Backup `json:"backups"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
type decryptionReaderCloser struct {
|
||||
type DecryptionReaderCloser struct {
|
||||
*encryption.DecryptionReader
|
||||
baseReader io.ReadCloser
|
||||
BaseReader io.ReadCloser
|
||||
}
|
||||
|
||||
func (r *decryptionReaderCloser) Close() error {
|
||||
return r.baseReader.Close()
|
||||
func (r *DecryptionReaderCloser) Close() error {
|
||||
return r.BaseReader.Close()
|
||||
}
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
users_models "databasus-backend/internal/features/users/models"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
@@ -28,25 +28,27 @@ import (
|
||||
type BackupService struct {
|
||||
databaseService *databases.DatabaseService
|
||||
storageService *storages.StorageService
|
||||
backupRepository *BackupRepository
|
||||
backupRepository *backups_core.BackupRepository
|
||||
notifierService *notifiers.NotifierService
|
||||
notificationSender NotificationSender
|
||||
notificationSender backups_core.NotificationSender
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
secretKeyService *encryption_secrets.SecretKeyService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
|
||||
createBackupUseCase CreateBackupUsecase
|
||||
createBackupUseCase backups_core.CreateBackupUsecase
|
||||
|
||||
logger *slog.Logger
|
||||
|
||||
backupRemoveListeners []BackupRemoveListener
|
||||
backupRemoveListeners []backups_core.BackupRemoveListener
|
||||
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
backupContextManager *BackupContextManager
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
taskCancelManager *task_cancellation.TaskCancelManager
|
||||
downloadTokenService *backups_download.DownloadTokenService
|
||||
backupSchedulerService *backuping.BackupsScheduler
|
||||
}
|
||||
|
||||
func (s *BackupService) AddBackupRemoveListener(listener BackupRemoveListener) {
|
||||
func (s *BackupService) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
|
||||
s.backupRemoveListeners = append(s.backupRemoveListeners, listener)
|
||||
}
|
||||
|
||||
@@ -89,7 +91,7 @@ func (s *BackupService) MakeBackupWithAuth(
|
||||
return errors.New("insufficient permissions to create backup for this database")
|
||||
}
|
||||
|
||||
go s.MakeBackup(databaseID, true)
|
||||
s.backupSchedulerService.StartBackup(databaseID, true)
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Backup manually initiated for database: %s", database.Name),
|
||||
@@ -173,7 +175,7 @@ func (s *BackupService) DeleteBackup(
|
||||
return errors.New("insufficient permissions to delete backup for this database")
|
||||
}
|
||||
|
||||
if backup.Status == BackupStatusInProgress {
|
||||
if backup.Status == backups_core.BackupStatusInProgress {
|
||||
return errors.New("backup is in progress")
|
||||
}
|
||||
|
||||
@@ -190,265 +192,7 @@ func (s *BackupService) DeleteBackup(
|
||||
return s.deleteBackup(backup)
|
||||
}
|
||||
|
||||
func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
|
||||
database, err := s.databaseService.GetDatabaseByID(databaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get database by ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
lastBackup, err := s.backupRepository.FindLastByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to find last backup by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if lastBackup != nil && lastBackup.Status == BackupStatusInProgress {
|
||||
s.logger.Error("Backup is in progress")
|
||||
return
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !backupConfig.IsBackupsEnabled {
|
||||
s.logger.Info("Backups are not enabled for this database")
|
||||
return
|
||||
}
|
||||
|
||||
if backupConfig.StorageID == nil {
|
||||
s.logger.Error("Backup config storage ID is not defined")
|
||||
return
|
||||
}
|
||||
|
||||
storage, err := s.storageService.GetStorageByID(*backupConfig.StorageID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get storage by ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backup := &Backup{
|
||||
DatabaseID: databaseID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusInProgress,
|
||||
|
||||
BackupSizeMb: 0,
|
||||
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now().UTC()
|
||||
|
||||
backupProgressListener := func(
|
||||
completedMBs float64,
|
||||
) {
|
||||
backup.BackupSizeMb = completedMBs
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to update backup progress", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s.backupContextManager.RegisterBackup(backup.ID, cancel)
|
||||
defer s.backupContextManager.UnregisterBackup(backup.ID)
|
||||
|
||||
backupMetadata, err := s.createBackupUseCase.Execute(
|
||||
ctx,
|
||||
backup.ID,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
backupProgressListener,
|
||||
)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
|
||||
// Check if backup was cancelled (not due to shutdown)
|
||||
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
|
||||
strings.Contains(errMsg, "context canceled") ||
|
||||
errors.Is(err, context.Canceled)
|
||||
isShutdown := strings.Contains(errMsg, "shutdown")
|
||||
|
||||
if isCancelled && !isShutdown {
|
||||
backup.Status = BackupStatusCanceled
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save cancelled backup", "error", err)
|
||||
}
|
||||
|
||||
// Delete partial backup from storage
|
||||
storage, storageErr := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr == nil {
|
||||
if deleteErr := storage.DeleteFile(s.fieldEncryptor, backup.ID); deleteErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to delete partial backup file",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
deleteErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.FailMessage = &errMsg
|
||||
backup.Status = BackupStatusFailed
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if updateErr := s.databaseService.SetBackupError(databaseID, errMsg); updateErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save backup", "error", err)
|
||||
}
|
||||
|
||||
s.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&errMsg,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.Status = BackupStatusCompleted
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Update backup with encryption metadata if provided
|
||||
if backupMetadata != nil {
|
||||
backup.EncryptionSalt = backupMetadata.EncryptionSalt
|
||||
backup.EncryptionIV = backupMetadata.EncryptionIV
|
||||
backup.Encryption = backupMetadata.Encryption
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update database last backup time
|
||||
now := time.Now().UTC()
|
||||
if updateErr := s.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if backup.Status != BackupStatusCompleted && !isLastTry {
|
||||
return
|
||||
}
|
||||
|
||||
s.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupSuccess,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *BackupService) SendBackupNotification(
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
backup *Backup,
|
||||
notificationType backups_config.BackupNotificationType,
|
||||
errorMessage *string,
|
||||
) {
|
||||
database, err := s.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
workspace, err := s.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, notifier := range database.Notifiers {
|
||||
if !slices.Contains(
|
||||
backupConfig.SendNotificationsOn,
|
||||
notificationType,
|
||||
) {
|
||||
continue
|
||||
}
|
||||
|
||||
title := ""
|
||||
switch notificationType {
|
||||
case backups_config.NotificationBackupFailed:
|
||||
title = fmt.Sprintf(
|
||||
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
case backups_config.NotificationBackupSuccess:
|
||||
title = fmt.Sprintf(
|
||||
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
}
|
||||
|
||||
message := ""
|
||||
if errorMessage != nil {
|
||||
message = *errorMessage
|
||||
} else {
|
||||
// Format size conditionally
|
||||
var sizeStr string
|
||||
if backup.BackupSizeMb < 1024 {
|
||||
sizeStr = fmt.Sprintf("%.2f MB", backup.BackupSizeMb)
|
||||
} else {
|
||||
sizeGB := backup.BackupSizeMb / 1024
|
||||
sizeStr = fmt.Sprintf("%.2f GB", sizeGB)
|
||||
}
|
||||
|
||||
// Format duration as "0m 0s 0ms"
|
||||
totalMs := backup.BackupDurationMs
|
||||
minutes := totalMs / (1000 * 60)
|
||||
seconds := (totalMs % (1000 * 60)) / 1000
|
||||
durationStr := fmt.Sprintf("%dm %ds", minutes, seconds)
|
||||
|
||||
message = fmt.Sprintf(
|
||||
"Backup completed successfully in %s.\nCompressed backup size: %s",
|
||||
durationStr,
|
||||
sizeStr,
|
||||
)
|
||||
}
|
||||
|
||||
s.notificationSender.SendNotification(
|
||||
¬ifier,
|
||||
title,
|
||||
message,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupService) GetBackup(backupID uuid.UUID) (*Backup, error) {
|
||||
func (s *BackupService) GetBackup(backupID uuid.UUID) (*backups_core.Backup, error) {
|
||||
return s.backupRepository.FindByID(backupID)
|
||||
}
|
||||
|
||||
@@ -478,11 +222,11 @@ func (s *BackupService) CancelBackup(
|
||||
return errors.New("insufficient permissions to cancel backup for this database")
|
||||
}
|
||||
|
||||
if backup.Status != BackupStatusInProgress {
|
||||
if backup.Status != backups_core.BackupStatusInProgress {
|
||||
return errors.New("backup is not in progress")
|
||||
}
|
||||
|
||||
if err := s.backupContextManager.CancelBackup(backupID); err != nil {
|
||||
if err := s.taskCancelManager.CancelTask(backupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -502,7 +246,7 @@ func (s *BackupService) CancelBackup(
|
||||
func (s *BackupService) GetBackupFile(
|
||||
user *users_models.User,
|
||||
backupID uuid.UUID,
|
||||
) (io.ReadCloser, *Backup, *databases.Database, error) {
|
||||
) (io.ReadCloser, *backups_core.Backup, *databases.Database, error) {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
@@ -548,7 +292,7 @@ func (s *BackupService) GetBackupFile(
|
||||
return reader, backup, database, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) deleteBackup(backup *Backup) error {
|
||||
func (s *BackupService) deleteBackup(backup *backups_core.Backup) error {
|
||||
for _, listener := range s.backupRemoveListeners {
|
||||
if err := listener.OnBeforeBackupRemove(backup); err != nil {
|
||||
return err
|
||||
@@ -574,7 +318,7 @@ func (s *BackupService) deleteBackup(backup *Backup) error {
|
||||
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
|
||||
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
|
||||
databaseID,
|
||||
BackupStatusInProgress,
|
||||
backups_core.BackupStatusInProgress,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -683,8 +427,136 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
|
||||
|
||||
s.logger.Info("Returning encrypted backup with decryption", "backupId", backupID)
|
||||
|
||||
return &decryptionReaderCloser{
|
||||
decryptionReader,
|
||||
fileReader,
|
||||
return &DecryptionReaderCloser{
|
||||
DecryptionReader: decryptionReader,
|
||||
BaseReader: fileReader,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) GenerateDownloadToken(
|
||||
user *users_models.User,
|
||||
backupID uuid.UUID,
|
||||
) (*backups_download.GenerateDownloadTokenResponse, error) {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if database.WorkspaceID == nil {
|
||||
return nil, errors.New("cannot download backup for database without workspace")
|
||||
}
|
||||
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canAccess {
|
||||
return nil, errors.New("insufficient permissions to download backup for this database")
|
||||
}
|
||||
|
||||
token, err := s.downloadTokenService.Generate(backupID, user.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filename := s.generateBackupFilename(backup, database)
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Download token generated for backup of database: %s", database.Name),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return &backups_download.GenerateDownloadTokenResponse{
|
||||
Token: token,
|
||||
Filename: filename,
|
||||
BackupID: backupID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) ValidateDownloadToken(
|
||||
token string,
|
||||
) (*backups_download.DownloadToken, *backups_download.RateLimiter, error) {
|
||||
return s.downloadTokenService.ValidateAndConsume(token)
|
||||
}
|
||||
|
||||
func (s *BackupService) GetBackupFileWithoutAuth(
|
||||
backupID uuid.UUID,
|
||||
) (io.ReadCloser, *backups_core.Backup, *databases.Database, error) {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
reader, err := s.getBackupReader(backupID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
return reader, backup, database, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) WriteAuditLogForDownload(
|
||||
userID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
database *databases.Database,
|
||||
) {
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Backup file downloaded for database: %s (ID: %s)",
|
||||
database.Name,
|
||||
backup.ID.String(),
|
||||
),
|
||||
&userID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *BackupService) RefreshDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTokenService.RefreshDownloadLock(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) ReleaseDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTokenService.ReleaseDownloadLock(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) IsDownloadInProgress(userID uuid.UUID) bool {
|
||||
return s.downloadTokenService.IsDownloadInProgress(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) UnregisterDownload(userID uuid.UUID) {
|
||||
s.downloadTokenService.UnregisterDownload(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) generateBackupFilename(
|
||||
backup *backups_core.Backup,
|
||||
database *databases.Database,
|
||||
) string {
|
||||
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")
|
||||
safeName := sanitizeFilename(database.Name)
|
||||
extension := s.getBackupExtension(database.Type)
|
||||
return fmt.Sprintf("%s_backup_%s%s", safeName, timestamp, extension)
|
||||
}
|
||||
|
||||
func (s *BackupService) getBackupExtension(dbType databases.DatabaseType) string {
|
||||
switch dbType {
|
||||
case databases.DatabaseTypeMysql, databases.DatabaseTypeMariadb:
|
||||
return ".sql.zst"
|
||||
case databases.DatabaseTypePostgres:
|
||||
return ".dump"
|
||||
case databases.DatabaseTypeMongodb:
|
||||
return ".archive"
|
||||
default:
|
||||
return ".backup"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,20 +1,77 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func CreateTestRouter() *gin.Engine {
|
||||
return workspaces_testing.CreateTestRouter(
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
backups_config.GetBackupConfigController(),
|
||||
GetBackupController(),
|
||||
)
|
||||
|
||||
// Register public routes (no auth required - token-based)
|
||||
v1 := router.Group("/api/v1")
|
||||
GetBackupController().RegisterPublicRoutes(v1)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
// WaitForBackupCompletion waits for a new backup to be created and completed (or failed)
|
||||
// for the given database. It checks for backups with count greater than expectedInitialCount.
|
||||
func WaitForBackupCompletion(
|
||||
t *testing.T,
|
||||
databaseID uuid.UUID,
|
||||
expectedInitialCount int,
|
||||
timeout time.Duration,
|
||||
) {
|
||||
deadline := time.Now().UTC().Add(timeout)
|
||||
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
backups, err := backupRepository.FindByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
t.Logf("WaitForBackupCompletion: error finding backups: %v", err)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
t.Logf(
|
||||
"WaitForBackupCompletion: found %d backups (expected > %d)",
|
||||
len(backups),
|
||||
expectedInitialCount,
|
||||
)
|
||||
|
||||
if len(backups) > expectedInitialCount {
|
||||
// Check if the newest backup has completed or failed
|
||||
newestBackup := backups[0]
|
||||
t.Logf("WaitForBackupCompletion: newest backup status: %s", newestBackup.Status)
|
||||
|
||||
if newestBackup.Status == backups_core.BackupStatusCompleted ||
|
||||
newestBackup.Status == backups_core.BackupStatusFailed ||
|
||||
newestBackup.Status == backups_core.BackupStatusCanceled {
|
||||
t.Logf(
|
||||
"WaitForBackupCompletion: backup finished with status %s",
|
||||
newestBackup.Status,
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
|
||||
}
|
||||
|
||||
@@ -64,10 +64,6 @@ func (uc *CreateMariadbBackupUsecase) Execute(
|
||||
"storageId", storage.ID,
|
||||
)
|
||||
|
||||
if !backupConfig.IsBackupsEnabled {
|
||||
return nil, fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
|
||||
}
|
||||
|
||||
mdb := db.Mariadb
|
||||
if mdb == nil {
|
||||
return nil, fmt.Errorf("mariadb database configuration is required")
|
||||
@@ -111,16 +107,22 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
|
||||
"--user=" + mdb.Username,
|
||||
"--single-transaction",
|
||||
"--routines",
|
||||
"--triggers",
|
||||
"--events",
|
||||
"--quick",
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
if mdb.HasPrivilege("TRIGGER") {
|
||||
args = append(args, "--triggers")
|
||||
}
|
||||
if mdb.HasPrivilege("EVENT") {
|
||||
args = append(args, "--events")
|
||||
}
|
||||
|
||||
args = append(args, "--compress")
|
||||
|
||||
if mdb.IsHttps {
|
||||
args = append(args, "--ssl")
|
||||
args = append(args, "--skip-ssl-verify-server-cert")
|
||||
}
|
||||
|
||||
if mdb.Database != nil && *mdb.Database != "" {
|
||||
@@ -264,11 +266,24 @@ func (uc *CreateMariadbBackupUsecase) createTempMyCnfFile(
|
||||
mdbConfig *mariadbtypes.MariadbDatabase,
|
||||
password string,
|
||||
) (string, error) {
|
||||
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "mycnf_"+uuid.New().String())
|
||||
tempFolder := config.GetEnv().TempFolder
|
||||
if err := os.MkdirAll(tempFolder, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
|
||||
}
|
||||
if err := os.Chmod(tempFolder, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp(tempFolder, "mycnf_"+uuid.New().String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tempDir, 0700); err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to set temp directory permissions: %w", err)
|
||||
}
|
||||
|
||||
myCnfFile := filepath.Join(tempDir, ".my.cnf")
|
||||
|
||||
content := fmt.Sprintf(`[client]
|
||||
@@ -286,6 +301,7 @@ port=%d
|
||||
|
||||
err = os.WriteFile(myCnfFile, []byte(content), 0600)
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to write .my.cnf: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -58,10 +58,6 @@ func (uc *CreateMongodbBackupUsecase) Execute(
|
||||
"storageId", storage.ID,
|
||||
)
|
||||
|
||||
if !backupConfig.IsBackupsEnabled {
|
||||
return nil, fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
|
||||
}
|
||||
|
||||
mdb := db.Mongodb
|
||||
if mdb == nil {
|
||||
return nil, fmt.Errorf("mongodb database configuration is required")
|
||||
|
||||
@@ -64,10 +64,6 @@ func (uc *CreateMysqlBackupUsecase) Execute(
|
||||
"storageId", storage.ID,
|
||||
)
|
||||
|
||||
if !backupConfig.IsBackupsEnabled {
|
||||
return nil, fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
|
||||
}
|
||||
|
||||
my := db.Mysql
|
||||
if my == nil {
|
||||
return nil, fmt.Errorf("mysql database configuration is required")
|
||||
@@ -109,13 +105,18 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
|
||||
"--user=" + my.Username,
|
||||
"--single-transaction",
|
||||
"--routines",
|
||||
"--triggers",
|
||||
"--events",
|
||||
"--set-gtid-purged=OFF",
|
||||
"--quick",
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
if my.HasPrivilege("TRIGGER") {
|
||||
args = append(args, "--triggers")
|
||||
}
|
||||
if my.HasPrivilege("EVENT") {
|
||||
args = append(args, "--events")
|
||||
}
|
||||
|
||||
args = append(args, uc.getNetworkCompressionArgs(my.Version)...)
|
||||
|
||||
if my.IsHttps {
|
||||
@@ -279,11 +280,24 @@ func (uc *CreateMysqlBackupUsecase) createTempMyCnfFile(
|
||||
myConfig *mysqltypes.MysqlDatabase,
|
||||
password string,
|
||||
) (string, error) {
|
||||
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "mycnf_"+uuid.New().String())
|
||||
tempFolder := config.GetEnv().TempFolder
|
||||
if err := os.MkdirAll(tempFolder, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
|
||||
}
|
||||
if err := os.Chmod(tempFolder, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp(tempFolder, "mycnf_"+uuid.New().String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tempDir, 0700); err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to set temp directory permissions: %w", err)
|
||||
}
|
||||
|
||||
myCnfFile := filepath.Join(tempDir, ".my.cnf")
|
||||
|
||||
content := fmt.Sprintf(`[client]
|
||||
@@ -299,6 +313,7 @@ port=%d
|
||||
|
||||
err = os.WriteFile(myCnfFile, []byte(content), 0600)
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to write .my.cnf: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -69,10 +69,6 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
storage.ID,
|
||||
)
|
||||
|
||||
if !backupConfig.IsBackupsEnabled {
|
||||
return nil, fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
|
||||
}
|
||||
|
||||
pg := db.Postgresql
|
||||
|
||||
if pg == nil {
|
||||
@@ -139,7 +135,14 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
cmd := exec.CommandContext(ctx, pgBin, args...)
|
||||
uc.logger.Info("Executing PostgreSQL backup command", "command", cmd.String())
|
||||
|
||||
if err := uc.setupPgEnvironment(cmd, pgpassFile, db.Postgresql.IsHttps, password, db.Postgresql.CpuCount, pgBin); err != nil {
|
||||
if err := uc.setupPgEnvironment(
|
||||
cmd,
|
||||
pgpassFile,
|
||||
db.Postgresql.IsHttps,
|
||||
password,
|
||||
db.Postgresql.CpuCount,
|
||||
pgBin,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -335,11 +338,6 @@ func (uc *CreatePostgresqlBackupUsecase) buildPgDumpArgs(pg *pgtypes.PostgresqlD
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
// Add parallel jobs based on CPU count
|
||||
if pg.CpuCount > 1 {
|
||||
args = append(args, "-j", strconv.Itoa(pg.CpuCount))
|
||||
}
|
||||
|
||||
for _, schema := range pg.IncludeSchemas {
|
||||
args = append(args, "-n", schema)
|
||||
}
|
||||
@@ -759,14 +757,28 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
|
||||
escapedPassword,
|
||||
)
|
||||
|
||||
tempDir, err := os.MkdirTemp("", "pgpass")
|
||||
tempFolder := config.GetEnv().TempFolder
|
||||
if err := os.MkdirAll(tempFolder, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
|
||||
}
|
||||
if err := os.Chmod(tempFolder, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp(tempFolder, "pgpass_"+uuid.New().String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temporary directory: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tempDir, 0700); err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to set temporary directory permissions: %w", err)
|
||||
}
|
||||
|
||||
pgpassFile := filepath.Join(tempDir, ".pgpass")
|
||||
err = os.WriteFile(pgpassFile, []byte(pgpassContent), 0600)
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to write temporary .pgpass file: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package backups_config
|
||||
|
||||
import (
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -16,6 +18,8 @@ func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
router.POST("/backup-configs/save", c.SaveBackupConfig)
|
||||
router.GET("/backup-configs/database/:id", c.GetBackupConfigByDbID)
|
||||
router.GET("/backup-configs/storage/:id/is-using", c.IsStorageUsing)
|
||||
router.GET("/backup-configs/storage/:id/databases-count", c.CountDatabasesForStorage)
|
||||
router.POST("/backup-configs/database/:id/transfer", c.TransferDatabase)
|
||||
}
|
||||
|
||||
// SaveBackupConfig
|
||||
@@ -120,3 +124,86 @@ func (c *BackupConfigController) IsStorageUsing(ctx *gin.Context) {
|
||||
|
||||
ctx.JSON(http.StatusOK, gin.H{"isUsing": isUsing})
|
||||
}
|
||||
|
||||
// CountDatabasesForStorage
|
||||
// @Summary Count databases using a storage
|
||||
// @Description Get the count of databases that are using a specific storage
|
||||
// @Tags backup-configs
|
||||
// @Produce json
|
||||
// @Param id path string true "Storage ID"
|
||||
// @Success 200 {object} map[string]int
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /backup-configs/storage/{id}/databases-count [get]
|
||||
func (c *BackupConfigController) CountDatabasesForStorage(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid storage ID"})
|
||||
return
|
||||
}
|
||||
|
||||
count, err := c.backupConfigService.CountDatabasesForStorage(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, gin.H{"count": count})
|
||||
}
|
||||
|
||||
// TransferDatabase
|
||||
// @Summary Transfer database to another workspace
|
||||
// @Description Transfer a database from one workspace to another. Can transfer to a new storage or transfer with the existing storage. Can also specify target notifiers from the target workspace.
|
||||
// @Tags backup-configs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "Database ID"
|
||||
// @Param request body TransferDatabaseRequest true "Transfer request with targetWorkspaceId, storage options (targetStorageId or isTransferWithStorage), and optional targetNotifierIds"
|
||||
// @Success 200 {object} map[string]string "Database transferred successfully"
|
||||
// @Failure 400 {object} map[string]string "Invalid request, target storage/notifier not in target workspace, or transfer failed"
|
||||
// @Failure 401 {object} map[string]string "User not authenticated"
|
||||
// @Failure 403 {object} map[string]string "Insufficient permissions"
|
||||
// @Router /backup-configs/database/{id}/transfer [post]
|
||||
func (c *BackupConfigController) TransferDatabase(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var request TransferDatabaseRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if request.TargetWorkspaceID == uuid.Nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "targetWorkspaceId is required"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.backupConfigService.TransferDatabaseToWorkspace(user, id, &request); err != nil {
|
||||
if errors.Is(err, ErrInsufficientPermissionsInSourceWorkspace) ||
|
||||
errors.Is(err, ErrInsufficientPermissionsInTargetWorkspace) {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, gin.H{"message": "database transferred successfully"})
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,6 +2,7 @@ package backups_config
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
)
|
||||
@@ -11,6 +12,7 @@ var backupConfigService = &BackupConfigService{
|
||||
backupConfigRepository,
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
}
|
||||
@@ -25,3 +27,7 @@ func GetBackupConfigController() *BackupConfigController {
|
||||
func GetBackupConfigService() *BackupConfigService {
|
||||
return backupConfigService
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
|
||||
}
|
||||
|
||||
11
backend/internal/features/backups/config/dto.go
Normal file
11
backend/internal/features/backups/config/dto.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package backups_config
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type TransferDatabaseRequest struct {
|
||||
TargetWorkspaceID uuid.UUID `json:"targetWorkspaceId" binding:"required"`
|
||||
TargetStorageID *uuid.UUID `json:"targetStorageId,omitempty"`
|
||||
IsTransferWithStorage bool `json:"isTransferWithStorage,omitempty"`
|
||||
IsTransferWithNotifiers bool `json:"isTransferWithNotifiers,omitempty"`
|
||||
TargetNotifierIDs []uuid.UUID `json:"targetNotifierIds,omitempty"`
|
||||
}
|
||||
30
backend/internal/features/backups/config/errors.go
Normal file
30
backend/internal/features/backups/config/errors.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package backups_config
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrInsufficientPermissionsInSourceWorkspace = errors.New(
|
||||
"insufficient permissions to manage database in source workspace",
|
||||
)
|
||||
ErrInsufficientPermissionsInTargetWorkspace = errors.New(
|
||||
"insufficient permissions to manage database in target workspace",
|
||||
)
|
||||
ErrTargetStorageNotInTargetWorkspace = errors.New(
|
||||
"target storage does not belong to target workspace",
|
||||
)
|
||||
ErrTargetNotifierNotInTargetWorkspace = errors.New(
|
||||
"target notifier does not belong to target workspace",
|
||||
)
|
||||
ErrStorageHasOtherAttachedDatabases = errors.New(
|
||||
"storage has other attached databases and cannot be transferred with this database",
|
||||
)
|
||||
ErrDatabaseHasNoStorage = errors.New(
|
||||
"database has no storage attached",
|
||||
)
|
||||
ErrDatabaseHasNoWorkspace = errors.New(
|
||||
"database has no workspace",
|
||||
)
|
||||
ErrTargetStorageNotSpecified = errors.New(
|
||||
"target storage is not specified",
|
||||
)
|
||||
)
|
||||
166
backend/internal/features/backups/config/notifiers_test.go
Normal file
166
backend/internal/features/backups/config/notifiers_test.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package backups_config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
)
|
||||
|
||||
func Test_AttachNotifierFromSameWorkspace_SuccessfullyAttached(t *testing.T) {
|
||||
router := createTestRouterWithNotifier()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
|
||||
database.Notifiers = []notifiers.Notifier{*notifier}
|
||||
|
||||
var response databases.Database
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/update",
|
||||
"Bearer "+owner.Token,
|
||||
database,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.ID)
|
||||
assert.Len(t, response.Notifiers, 1)
|
||||
assert.Equal(t, notifier.ID, response.Notifiers[0].ID)
|
||||
}
|
||||
|
||||
func Test_AttachNotifierFromDifferentWorkspace_ReturnsForbidden(t *testing.T) {
|
||||
router := createTestRouterWithNotifier()
|
||||
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace1 := workspaces_testing.CreateTestWorkspace("Workspace 1", owner1, router)
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace1.ID, owner1.Token, router)
|
||||
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
|
||||
notifier := notifiers.CreateTestNotifier(workspace2.ID)
|
||||
|
||||
database.Notifiers = []notifiers.Notifier{*notifier}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/update",
|
||||
"Bearer "+owner1.Token,
|
||||
database,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "notifier does not belong to this workspace")
|
||||
}
|
||||
|
||||
func Test_DeleteNotifierWithAttachedDatabases_CannotDelete(t *testing.T) {
|
||||
router := createTestRouterWithNotifier()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
|
||||
database.Notifiers = []notifiers.Notifier{*notifier}
|
||||
|
||||
var response databases.Database
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/update",
|
||||
"Bearer "+owner.Token,
|
||||
database,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
testResp := test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/notifiers/%s", notifier.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(
|
||||
t,
|
||||
string(testResp.Body),
|
||||
"notifier has attached databases and cannot be deleted",
|
||||
)
|
||||
}
|
||||
|
||||
func Test_TransferNotifierWithAttachedDatabase_CannotTransfer(t *testing.T) {
|
||||
router := createTestRouterWithNotifier()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
targetWorkspace := workspaces_testing.CreateTestWorkspace("Target Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
|
||||
database.Notifiers = []notifiers.Notifier{*notifier}
|
||||
|
||||
var response databases.Database
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/update",
|
||||
"Bearer "+owner.Token,
|
||||
database,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
transferRequest := notifiers.TransferNotifierRequest{
|
||||
TargetWorkspaceID: targetWorkspace.ID,
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/notifiers/%s/transfer", notifier.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
transferRequest,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(
|
||||
t,
|
||||
string(testResp.Body),
|
||||
"notifier has attached databases and cannot be transferred",
|
||||
)
|
||||
}
|
||||
|
||||
func createTestRouterWithNotifier() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
GetBackupConfigController(),
|
||||
storages.GetStorageController(),
|
||||
notifiers.GetNotifierController(),
|
||||
)
|
||||
|
||||
storages.SetupDependencies()
|
||||
databases.SetupDependencies()
|
||||
notifiers.SetupDependencies()
|
||||
SetupDependencies()
|
||||
|
||||
return router
|
||||
}
|
||||
@@ -102,3 +102,19 @@ func (r *BackupConfigRepository) IsStorageUsing(storageID uuid.UUID) (bool, erro
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *BackupConfigRepository) GetDatabasesIDsByStorageID(
|
||||
storageID uuid.UUID,
|
||||
) ([]uuid.UUID, error) {
|
||||
var databasesIDs []uuid.UUID
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Table("backup_configs").
|
||||
Where("storage_id = ?", storageID).
|
||||
Pluck("database_id", &databasesIDs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return databasesIDs, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_models "databasus-backend/internal/features/users/models"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
@@ -17,6 +18,7 @@ type BackupConfigService struct {
|
||||
backupConfigRepository *BackupConfigRepository
|
||||
databaseService *databases.DatabaseService
|
||||
storageService *storages.StorageService
|
||||
notifierService *notifiers.NotifierService
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
|
||||
dbStorageChangeListener BackupConfigStorageChangeListener
|
||||
@@ -28,6 +30,17 @@ func (s *BackupConfigService) SetDatabaseStorageChangeListener(
|
||||
s.dbStorageChangeListener = dbStorageChangeListener
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) GetStorageAttachedDatabasesIDs(
|
||||
storageID uuid.UUID,
|
||||
) ([]uuid.UUID, error) {
|
||||
databasesIDs, err := s.backupConfigRepository.GetDatabasesIDsByStorageID(storageID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return databasesIDs, nil
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) SaveBackupConfigWithAuth(
|
||||
user *users_models.User,
|
||||
backupConfig *BackupConfig,
|
||||
@@ -53,6 +66,16 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
|
||||
return nil, errors.New("insufficient permissions to modify backup configuration")
|
||||
}
|
||||
|
||||
if backupConfig.Storage != nil && backupConfig.Storage.ID != uuid.Nil {
|
||||
storage, err := s.storageService.GetStorageByID(backupConfig.Storage.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if storage.WorkspaceID != *database.WorkspaceID {
|
||||
return nil, errors.New("storage does not belong to the same workspace as the database")
|
||||
}
|
||||
}
|
||||
|
||||
return s.SaveBackupConfig(backupConfig)
|
||||
}
|
||||
|
||||
@@ -129,6 +152,23 @@ func (s *BackupConfigService) IsStorageUsing(
|
||||
return s.backupConfigRepository.IsStorageUsing(storageID)
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) CountDatabasesForStorage(
|
||||
user *users_models.User,
|
||||
storageID uuid.UUID,
|
||||
) (int, error) {
|
||||
_, err := s.storageService.GetStorage(user, storageID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
databaseIDs, err := s.backupConfigRepository.GetDatabasesIDsByStorageID(storageID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(databaseIDs), nil
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) GetBackupConfigsWithEnabledBackups() ([]*BackupConfig, error) {
|
||||
return s.backupConfigRepository.GetWithEnabledBackups()
|
||||
}
|
||||
@@ -176,6 +216,157 @@ func (s *BackupConfigService) initializeDefaultConfig(
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) TransferDatabaseToWorkspace(
|
||||
user *users_models.User,
|
||||
databaseID uuid.UUID,
|
||||
request *TransferDatabaseRequest,
|
||||
) error {
|
||||
database, err := s.databaseService.GetDatabaseByID(databaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if database.WorkspaceID == nil {
|
||||
return ErrDatabaseHasNoWorkspace
|
||||
}
|
||||
|
||||
canManageSource, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManageSource {
|
||||
return ErrInsufficientPermissionsInSourceWorkspace
|
||||
}
|
||||
|
||||
canManageTarget, err := s.workspaceService.CanUserManageDBs(request.TargetWorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManageTarget {
|
||||
return ErrInsufficientPermissionsInTargetWorkspace
|
||||
}
|
||||
|
||||
if err := s.validateTargetNotifiers(request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backupConfig, err := s.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if request.IsTransferWithNotifiers {
|
||||
s.transferNotifiers(user, database, request.TargetWorkspaceID)
|
||||
}
|
||||
|
||||
if request.IsTransferWithStorage {
|
||||
if backupConfig.StorageID == nil {
|
||||
return ErrDatabaseHasNoStorage
|
||||
}
|
||||
|
||||
attachedDatabasesIDs, err := s.GetStorageAttachedDatabasesIDs(*backupConfig.StorageID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, dbID := range attachedDatabasesIDs {
|
||||
if dbID != databaseID {
|
||||
return ErrStorageHasOtherAttachedDatabases
|
||||
}
|
||||
}
|
||||
|
||||
err = s.storageService.TransferStorageToWorkspace(
|
||||
user,
|
||||
*backupConfig.StorageID,
|
||||
request.TargetWorkspaceID,
|
||||
&databaseID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if request.TargetStorageID != nil {
|
||||
targetStorage, err := s.storageService.GetStorageByID(*request.TargetStorageID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if targetStorage.WorkspaceID != request.TargetWorkspaceID {
|
||||
return ErrTargetStorageNotInTargetWorkspace
|
||||
}
|
||||
|
||||
backupConfig.StorageID = request.TargetStorageID
|
||||
backupConfig.Storage = targetStorage
|
||||
|
||||
_, err = s.backupConfigRepository.Save(backupConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return ErrTargetStorageNotSpecified
|
||||
}
|
||||
|
||||
err = s.databaseService.TransferDatabaseToWorkspace(databaseID, request.TargetWorkspaceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(request.TargetNotifierIDs) > 0 {
|
||||
err = s.assignTargetNotifiers(databaseID, request.TargetNotifierIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) transferNotifiers(
|
||||
user *users_models.User,
|
||||
database *databases.Database,
|
||||
targetWorkspaceID uuid.UUID,
|
||||
) {
|
||||
for _, notifier := range database.Notifiers {
|
||||
_ = s.notifierService.TransferNotifierToWorkspace(
|
||||
user,
|
||||
notifier.ID,
|
||||
targetWorkspaceID,
|
||||
&database.ID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) validateTargetNotifiers(request *TransferDatabaseRequest) error {
|
||||
for _, notifierID := range request.TargetNotifierIDs {
|
||||
notifier, err := s.notifierService.GetNotifierByID(notifierID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if notifier.WorkspaceID != request.TargetWorkspaceID {
|
||||
return ErrTargetNotifierNotInTargetWorkspace
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) assignTargetNotifiers(
|
||||
databaseID uuid.UUID,
|
||||
notifierIDs []uuid.UUID,
|
||||
) error {
|
||||
targetNotifiers := make([]notifiers.Notifier, 0, len(notifierIDs))
|
||||
|
||||
for _, notifierID := range notifierIDs {
|
||||
notifier, err := s.notifierService.GetNotifierByID(notifierID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targetNotifiers = append(targetNotifiers, *notifier)
|
||||
}
|
||||
|
||||
return s.databaseService.UpdateDatabaseNotifiers(databaseID, targetNotifiers)
|
||||
}
|
||||
|
||||
func storageIDsEqual(id1, id2 *uuid.UUID) bool {
|
||||
if id1 == nil && id2 == nil {
|
||||
return true
|
||||
|
||||
229
backend/internal/features/backups/config/storages_test.go
Normal file
229
backend/internal/features/backups/config/storages_test.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package backups_config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/period"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
)
|
||||
|
||||
func Test_AttachStorageFromSameWorkspace_SuccessfullyAttached(t *testing.T) {
|
||||
router := createTestRouterWithStorage()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
request := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
Storage: storage,
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
}
|
||||
|
||||
var response BackupConfig
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.NotNil(t, response.StorageID)
|
||||
assert.Equal(t, storage.ID, *response.StorageID)
|
||||
}
|
||||
|
||||
func Test_AttachStorageFromDifferentWorkspace_ReturnsForbidden(t *testing.T) {
|
||||
router := createTestRouterWithStorage()
|
||||
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace1 := workspaces_testing.CreateTestWorkspace("Workspace 1", owner1, router)
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace1.ID, owner1.Token, router)
|
||||
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
|
||||
storage := createTestStorage(workspace2.ID)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
request := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
Storage: storage,
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner1.Token,
|
||||
request,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "storage does not belong to the same workspace")
|
||||
}
|
||||
|
||||
func Test_DeleteStorageWithAttachedDatabases_CannotDelete(t *testing.T) {
|
||||
router := createTestRouterWithStorage()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
request := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
Storage: storage,
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
}
|
||||
|
||||
var response BackupConfig
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
testResp := test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages/%s", storage.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(
|
||||
t,
|
||||
string(testResp.Body),
|
||||
"storage has attached databases and cannot be deleted",
|
||||
)
|
||||
}
|
||||
|
||||
func Test_TransferStorageWithAttachedDatabase_CannotTransfer(t *testing.T) {
|
||||
router := createTestRouterWithStorage()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
targetWorkspace := workspaces_testing.CreateTestWorkspace("Target Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
request := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
Storage: storage,
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
}
|
||||
|
||||
var response BackupConfig
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
transferRequest := storages.TransferStorageRequest{
|
||||
TargetWorkspaceID: targetWorkspace.ID,
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages/%s/transfer", storage.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
transferRequest,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(
|
||||
t,
|
||||
string(testResp.Body),
|
||||
"storage has attached databases and cannot be transferred",
|
||||
)
|
||||
}
|
||||
|
||||
func createTestRouterWithStorage() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
GetBackupConfigController(),
|
||||
storages.GetStorageController(),
|
||||
)
|
||||
|
||||
storages.SetupDependencies()
|
||||
databases.SetupDependencies()
|
||||
SetupDependencies()
|
||||
|
||||
return router
|
||||
}
|
||||
@@ -26,6 +26,7 @@ func (c *DatabaseController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
router.POST("/databases/test-connection-direct", c.TestDatabaseConnectionDirect)
|
||||
router.POST("/databases/:id/copy", c.CopyDatabase)
|
||||
router.GET("/databases/notifier/:id/is-using", c.IsNotifierUsing)
|
||||
router.GET("/databases/notifier/:id/databases-count", c.CountDatabasesByNotifier)
|
||||
router.POST("/databases/is-readonly", c.IsUserReadOnly)
|
||||
router.POST("/databases/create-readonly-user", c.CreateReadOnlyUser)
|
||||
}
|
||||
@@ -299,6 +300,39 @@ func (c *DatabaseController) IsNotifierUsing(ctx *gin.Context) {
|
||||
ctx.JSON(http.StatusOK, gin.H{"isUsing": isUsing})
|
||||
}
|
||||
|
||||
// CountDatabasesByNotifier
|
||||
// @Summary Count databases using a notifier
|
||||
// @Description Get the count of databases that are using a specific notifier
|
||||
// @Tags databases
|
||||
// @Produce json
|
||||
// @Param id path string true "Notifier ID"
|
||||
// @Success 200 {object} map[string]int
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /databases/notifier/{id}/databases-count [get]
|
||||
func (c *DatabaseController) CountDatabasesByNotifier(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid notifier ID"})
|
||||
return
|
||||
}
|
||||
|
||||
count, err := c.databaseService.CountDatabasesByNotifier(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, gin.H{"count": count})
|
||||
}
|
||||
|
||||
// CopyDatabase
|
||||
// @Summary Copy a database
|
||||
// @Description Copy an existing database configuration
|
||||
@@ -358,13 +392,13 @@ func (c *DatabaseController) IsUserReadOnly(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
isReadOnly, err := c.databaseService.IsUserReadOnly(user, &request)
|
||||
isReadOnly, privileges, err := c.databaseService.IsUserReadOnly(user, &request)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, IsReadOnlyResponse{IsReadOnly: isReadOnly})
|
||||
ctx.JSON(http.StatusOK, IsReadOnlyResponse{IsReadOnly: isReadOnly, Privileges: privileges})
|
||||
}
|
||||
|
||||
// CreateReadOnlyUser
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/databases/databases/mariadb"
|
||||
"databasus-backend/internal/features/databases/databases/mongodb"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
@@ -32,6 +34,71 @@ func createTestRouter() *gin.Engine {
|
||||
return router
|
||||
}
|
||||
|
||||
func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
|
||||
env := config.GetEnv()
|
||||
port, err := strconv.Atoi(env.TestPostgres16Port)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
return &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func getTestMariadbConfig() *mariadb.MariadbDatabase {
|
||||
env := config.GetEnv()
|
||||
portStr := env.TestMariadb1011Port
|
||||
if portStr == "" {
|
||||
portStr = "33111"
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
return &mariadb.MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
}
|
||||
}
|
||||
|
||||
func getTestMongodbConfig() *mongodb.MongodbDatabase {
|
||||
env := config.GetEnv()
|
||||
portStr := env.TestMongodb70Port
|
||||
if portStr == "" {
|
||||
portStr = "27070"
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err))
|
||||
}
|
||||
|
||||
return &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion7,
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
Username: "root",
|
||||
Password: "rootpassword",
|
||||
Database: "testdb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -84,24 +151,21 @@ func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
|
||||
testUserToken = owner.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
member,
|
||||
*tt.workspaceRole,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
|
||||
testDbName := "test_db"
|
||||
request := Database{
|
||||
Name: "Test Database",
|
||||
WorkspaceID: &workspace.ID,
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Postgresql: getTestPostgresConfig(),
|
||||
}
|
||||
|
||||
var response Database
|
||||
@@ -132,20 +196,11 @@ func Test_CreateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
|
||||
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
testDbName := "test_db"
|
||||
request := Database{
|
||||
Name: "Test Database",
|
||||
WorkspaceID: &workspace.ID,
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Postgresql: getTestPostgresConfig(),
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
@@ -214,7 +269,13 @@ func Test_UpdateDatabase_PermissionsEnforced(t *testing.T) {
|
||||
testUserToken = owner.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
member,
|
||||
*tt.workspaceRole,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
|
||||
@@ -316,7 +377,13 @@ func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) {
|
||||
testUserToken = owner.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
member,
|
||||
*tt.workspaceRole,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
|
||||
@@ -381,7 +448,13 @@ func Test_GetDatabase_PermissionsEnforced(t *testing.T) {
|
||||
testUser = admin.Token
|
||||
} else if tt.userRole != nil {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.userRole, owner.Token, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
member,
|
||||
*tt.userRole,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
testUser = member.Token
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
@@ -605,7 +678,13 @@ func Test_CopyDatabase_PermissionsEnforced(t *testing.T) {
|
||||
testUserToken = owner.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
member,
|
||||
*tt.workspaceRole,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
|
||||
@@ -737,7 +816,13 @@ func createTestDatabaseViaAPI(
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *Database {
|
||||
testDbName := "test_db"
|
||||
env := config.GetEnv()
|
||||
port, err := strconv.Atoi(env.TestPostgres16Port)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
request := Database{
|
||||
Name: name,
|
||||
WorkspaceID: &workspaceID,
|
||||
@@ -745,9 +830,9 @@ func createTestDatabaseViaAPI(
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
@@ -780,21 +865,14 @@ func Test_CreateDatabase_PasswordIsEncryptedInDB(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
testDbName := "test_db"
|
||||
plainPassword := "my-super-secret-password-123"
|
||||
pgConfig := getTestPostgresConfig()
|
||||
plainPassword := "testpassword"
|
||||
pgConfig.Password = plainPassword
|
||||
request := Database{
|
||||
Name: "Test Database",
|
||||
WorkspaceID: &workspace.ID,
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: plainPassword,
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Postgresql: pgConfig,
|
||||
}
|
||||
|
||||
var createdDatabase Database
|
||||
@@ -854,38 +932,23 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
name: "PostgreSQL Database",
|
||||
databaseType: DatabaseTypePostgres,
|
||||
createDatabase: func(workspaceID uuid.UUID) *Database {
|
||||
testDbName := "test_db"
|
||||
pgConfig := getTestPostgresConfig()
|
||||
return &Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: "Test PostgreSQL Database",
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "original-password-secret",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Postgresql: pgConfig,
|
||||
}
|
||||
},
|
||||
updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database {
|
||||
testDbName := "updated_test_db"
|
||||
pgConfig := getTestPostgresConfig()
|
||||
pgConfig.Password = ""
|
||||
return &Database{
|
||||
ID: databaseID,
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: "Updated PostgreSQL Database",
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion17,
|
||||
Host: "updated-host",
|
||||
Port: 5433,
|
||||
Username: "updated_user",
|
||||
Password: "",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Postgresql: pgConfig,
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, database *Database) {
|
||||
@@ -895,7 +958,7 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
decrypted, err := encryptor.Decrypt(database.ID, database.Postgresql.Password)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-password-secret", decrypted)
|
||||
assert.Equal(t, "testpassword", decrypted)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, database *Database) {
|
||||
assert.Equal(t, "", database.Postgresql.Password)
|
||||
@@ -905,36 +968,23 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
name: "MariaDB Database",
|
||||
databaseType: DatabaseTypeMariadb,
|
||||
createDatabase: func(workspaceID uuid.UUID) *Database {
|
||||
testDbName := "test_db"
|
||||
mariaConfig := getTestMariadbConfig()
|
||||
return &Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: "Test MariaDB Database",
|
||||
Type: DatabaseTypeMariadb,
|
||||
Mariadb: &mariadb.MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: "localhost",
|
||||
Port: 3306,
|
||||
Username: "root",
|
||||
Password: "original-password-secret",
|
||||
Database: &testDbName,
|
||||
},
|
||||
Mariadb: mariaConfig,
|
||||
}
|
||||
},
|
||||
updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database {
|
||||
testDbName := "updated_test_db"
|
||||
mariaConfig := getTestMariadbConfig()
|
||||
mariaConfig.Password = ""
|
||||
return &Database{
|
||||
ID: databaseID,
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: "Updated MariaDB Database",
|
||||
Type: DatabaseTypeMariadb,
|
||||
Mariadb: &mariadb.MariadbDatabase{
|
||||
Version: tools.MariadbVersion114,
|
||||
Host: "updated-host",
|
||||
Port: 3307,
|
||||
Username: "updated_user",
|
||||
Password: "",
|
||||
Database: &testDbName,
|
||||
},
|
||||
Mariadb: mariaConfig,
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, database *Database) {
|
||||
@@ -944,7 +994,7 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
decrypted, err := encryptor.Decrypt(database.ID, database.Mariadb.Password)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-password-secret", decrypted)
|
||||
assert.Equal(t, "testpassword", decrypted)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, database *Database) {
|
||||
assert.Equal(t, "", database.Mariadb.Password)
|
||||
@@ -954,40 +1004,23 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
name: "MongoDB Database",
|
||||
databaseType: DatabaseTypeMongodb,
|
||||
createDatabase: func(workspaceID uuid.UUID) *Database {
|
||||
mongoConfig := getTestMongodbConfig()
|
||||
return &Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: "Test MongoDB Database",
|
||||
Type: DatabaseTypeMongodb,
|
||||
Mongodb: &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion7,
|
||||
Host: "localhost",
|
||||
Port: 27017,
|
||||
Username: "root",
|
||||
Password: "original-password-secret",
|
||||
Database: "test_db",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Mongodb: mongoConfig,
|
||||
}
|
||||
},
|
||||
updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database {
|
||||
mongoConfig := getTestMongodbConfig()
|
||||
mongoConfig.Password = ""
|
||||
return &Database{
|
||||
ID: databaseID,
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: "Updated MongoDB Database",
|
||||
Type: DatabaseTypeMongodb,
|
||||
Mongodb: &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion8,
|
||||
Host: "updated-host",
|
||||
Port: 27018,
|
||||
Username: "updated_user",
|
||||
Password: "",
|
||||
Database: "updated_test_db",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Mongodb: mongoConfig,
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, database *Database) {
|
||||
@@ -997,7 +1030,7 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
decrypted, err := encryptor.Decrypt(database.ID, database.Mongodb.Password)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-password-secret", decrypted)
|
||||
assert.Equal(t, "rootpassword", decrypted)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, database *Database) {
|
||||
assert.Equal(t, "", database.Mongodb.Password)
|
||||
|
||||
@@ -2,18 +2,20 @@ package mariadb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/tools"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
@@ -23,12 +25,13 @@ type MariadbDatabase struct {
|
||||
|
||||
Version tools.MariadbVersion `json:"version" gorm:"type:text;not null"`
|
||||
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database *string `json:"database" gorm:"type:text"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database *string `json:"database" gorm:"type:text"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) TableName() string {
|
||||
@@ -94,6 +97,16 @@ func (m *MariadbDatabase) TestConnection(
|
||||
}
|
||||
m.Version = detectedVersion
|
||||
|
||||
privileges, err := detectPrivileges(ctx, db, *m.Database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Privileges = privileges
|
||||
|
||||
if err := checkBackupPermissions(m.Privileges); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -111,6 +124,7 @@ func (m *MariadbDatabase) Update(incoming *MariadbDatabase) {
|
||||
m.Username = incoming.Username
|
||||
m.Database = incoming.Database
|
||||
m.IsHttps = incoming.IsHttps
|
||||
m.Privileges = incoming.Privileges
|
||||
|
||||
if incoming.Password != "" {
|
||||
m.Password = incoming.Password
|
||||
@@ -131,15 +145,48 @@ func (m *MariadbDatabase) EncryptSensitiveFields(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) PopulateVersionIfEmpty(
|
||||
func (m *MariadbDatabase) PopulateDbData(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
if m.Version != "" {
|
||||
if m.Database == nil || *m.Database == "" {
|
||||
return nil
|
||||
}
|
||||
return m.PopulateVersion(logger, encryptor, databaseID)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
dsn := m.buildDSN(password, *m.Database)
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := db.Close(); closeErr != nil {
|
||||
logger.Error("Failed to close connection", "error", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
detectedVersion, err := detectMariadbVersion(ctx, db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Version = detectedVersion
|
||||
|
||||
privileges, err := detectPrivileges(ctx, db, *m.Database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Privileges = privileges
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) PopulateVersion(
|
||||
@@ -175,8 +222,8 @@ func (m *MariadbDatabase) PopulateVersion(
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Version = detectedVersion
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -185,17 +232,17 @@ func (m *MariadbDatabase) IsUserReadOnly(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (bool, error) {
|
||||
) (bool, []string, error) {
|
||||
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
dsn := m.buildDSN(password, *m.Database)
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to connect to database: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := db.Close(); closeErr != nil {
|
||||
@@ -205,33 +252,44 @@ func (m *MariadbDatabase) IsUserReadOnly(
|
||||
|
||||
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check grants: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to check grants: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
writePrivileges := []string{
|
||||
"INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER",
|
||||
"INDEX", "GRANT OPTION", "ALL PRIVILEGES", "SUPER",
|
||||
"EXECUTE", "FILE", "RELOAD", "SHUTDOWN", "CREATE ROUTINE",
|
||||
"ALTER ROUTINE", "CREATE USER",
|
||||
"CREATE TABLESPACE", "DELETE HISTORY", "REFERENCES",
|
||||
}
|
||||
|
||||
detectedPrivileges := make(map[string]bool)
|
||||
|
||||
for rows.Next() {
|
||||
var grant string
|
||||
if err := rows.Scan(&grant); err != nil {
|
||||
return false, fmt.Errorf("failed to scan grant: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to scan grant: %w", err)
|
||||
}
|
||||
|
||||
for _, priv := range writePrivileges {
|
||||
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
|
||||
return false, nil
|
||||
detectedPrivileges[priv] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return false, fmt.Errorf("error iterating grants: %w", err)
|
||||
return false, nil, fmt.Errorf("error iterating grants: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
privileges := make([]string, 0, len(detectedPrivileges))
|
||||
for priv := range detectedPrivileges {
|
||||
privileges = append(privileges, priv)
|
||||
}
|
||||
|
||||
isReadOnly := len(privileges) == 0
|
||||
return isReadOnly, privileges, nil
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) CreateReadOnlyUser(
|
||||
@@ -261,7 +319,7 @@ func (m *MariadbDatabase) CreateReadOnlyUser(
|
||||
for attempt := range maxRetries {
|
||||
// MariaDB 5.5 has a 16-character username limit, use shorter prefix
|
||||
newUsername := fmt.Sprintf("pgs-%s", uuid.New().String()[:8])
|
||||
newPassword := uuid.New().String()
|
||||
newPassword := encryption.GenerateComplexPassword()
|
||||
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
@@ -326,10 +384,31 @@ func (m *MariadbDatabase) CreateReadOnlyUser(
|
||||
return "", "", errors.New("failed to generate unique username after 3 attempts")
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) HasPrivilege(priv string) bool {
|
||||
return HasPrivilege(m.Privileges, priv)
|
||||
}
|
||||
|
||||
func HasPrivilege(privileges, priv string) bool {
|
||||
for _, p := range strings.Split(privileges, ",") {
|
||||
if strings.TrimSpace(p) == priv {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) buildDSN(password string, database string) string {
|
||||
tlsConfig := "false"
|
||||
|
||||
if m.IsHttps {
|
||||
tlsConfig = "true"
|
||||
err := mysql.RegisterTLSConfig("mariadb-skip-verify", &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
// Config might already be registered, which is fine
|
||||
_ = err
|
||||
}
|
||||
tlsConfig = "mariadb-skip-verify"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
@@ -420,6 +499,104 @@ func mapMariadb11xVersion(minor string) (tools.MariadbVersion, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// detectPrivileges detects backup-related privileges and returns them as comma-separated string
|
||||
func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string, error) {
|
||||
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to check grants: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
backupPrivileges := []string{
|
||||
"SELECT", "SHOW VIEW", "LOCK TABLES", "TRIGGER", "EVENT",
|
||||
}
|
||||
|
||||
detectedPrivileges := make(map[string]bool)
|
||||
hasProcess := false
|
||||
hasAllPrivileges := false
|
||||
|
||||
dbPatternStr := fmt.Sprintf(
|
||||
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
|
||||
regexp.QuoteMeta(database),
|
||||
)
|
||||
dbPattern := regexp.MustCompile(dbPatternStr)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
|
||||
allPrivilegesPattern := regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`)
|
||||
|
||||
for rows.Next() {
|
||||
var grant string
|
||||
if err := rows.Scan(&grant); err != nil {
|
||||
return "", fmt.Errorf("failed to scan grant: %w", err)
|
||||
}
|
||||
|
||||
isRelevantGrant := globalPattern.MatchString(grant) || dbPattern.MatchString(grant)
|
||||
|
||||
if allPrivilegesPattern.MatchString(grant) && isRelevantGrant {
|
||||
hasAllPrivileges = true
|
||||
}
|
||||
|
||||
if isRelevantGrant {
|
||||
for _, priv := range backupPrivileges {
|
||||
privPattern := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(priv) + `\b`)
|
||||
if privPattern.MatchString(grant) {
|
||||
detectedPrivileges[priv] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if globalPattern.MatchString(grant) {
|
||||
processPattern := regexp.MustCompile(`(?i)\bPROCESS\b`)
|
||||
if processPattern.MatchString(grant) {
|
||||
hasProcess = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return "", fmt.Errorf("error iterating grants: %w", err)
|
||||
}
|
||||
|
||||
if hasAllPrivileges {
|
||||
for _, priv := range backupPrivileges {
|
||||
detectedPrivileges[priv] = true
|
||||
}
|
||||
hasProcess = true
|
||||
}
|
||||
|
||||
privileges := make([]string, 0, len(detectedPrivileges)+1)
|
||||
for priv := range detectedPrivileges {
|
||||
privileges = append(privileges, priv)
|
||||
}
|
||||
if hasProcess {
|
||||
privileges = append(privileges, "PROCESS")
|
||||
}
|
||||
|
||||
sort.Strings(privileges)
|
||||
return strings.Join(privileges, ","), nil
|
||||
}
|
||||
|
||||
// checkBackupPermissions verifies the user has sufficient privileges for mariadb-dump backup.
|
||||
// Required: SELECT, SHOW VIEW
|
||||
func checkBackupPermissions(privileges string) error {
|
||||
requiredPrivileges := []string{"SELECT", "SHOW VIEW"}
|
||||
|
||||
var missingPrivileges []string
|
||||
for _, priv := range requiredPrivileges {
|
||||
if !HasPrivilege(privileges, priv) {
|
||||
missingPrivileges = append(missingPrivileges, priv)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingPrivileges) > 0 {
|
||||
return fmt.Errorf(
|
||||
"insufficient permissions for backup. Missing: %s. Required: SELECT, SHOW VIEW",
|
||||
strings.Join(missingPrivileges, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decryptPasswordIfNeeded(
|
||||
password string,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
|
||||
@@ -0,0 +1,757 @@
|
||||
package mariadb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMariadbContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS permission_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE permission_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO permission_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
limitedUsername := fmt.Sprintf("limited_%s", uuid.New().String()[:8])
|
||||
limitedPassword := "limitedpassword123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
limitedUsername,
|
||||
limitedPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT ON `%s`.* TO '%s'@'%%'",
|
||||
container.Database,
|
||||
limitedUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(container.DB, limitedUsername)
|
||||
|
||||
mariadbModel := &MariadbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: limitedUsername,
|
||||
Password: limitedPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mariadbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "insufficient permissions")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMariadbContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS backup_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE backup_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO backup_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupUsername := fmt.Sprintf("backup_%s", uuid.New().String()[:8])
|
||||
backupPassword := "backuppassword123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
backupUsername,
|
||||
backupPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT ON `%s`.* TO '%s'@'%%'",
|
||||
container.Database,
|
||||
backupUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT PROCESS ON *.* TO '%s'@'%%'",
|
||||
backupUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(container.DB, backupUsername)
|
||||
|
||||
mariadbModel := &MariadbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: backupUsername,
|
||||
Password: backupPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mariadbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMariadbContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
isReadOnly, privileges, err := mariadbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, isReadOnly, "Root user should not be read-only")
|
||||
assert.NotEmpty(t, privileges, "Root user should have privileges")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS readonly_check_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE readonly_check_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO readonly_check_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyModel := &MariadbDatabase{
|
||||
Version: mariadbModel.Version,
|
||||
Host: mariadbModel.Host,
|
||||
Port: mariadbModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: mariadbModel.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Read-only user should be read-only")
|
||||
assert.Empty(t, privileges, "Read-only user should have no write privileges")
|
||||
|
||||
dropUserSafe(container.DB, username)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMariadbContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS readonly_test`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`DROP TABLE IF EXISTS hack_table`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`DROP TABLE IF EXISTS future_table`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`
|
||||
CREATE TABLE readonly_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(
|
||||
`INSERT INTO readonly_test (data) VALUES ('test1'), ('test2')`,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "pgs-"))
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
readOnlyModel := &MariadbDatabase{
|
||||
Version: mariadbModel.Version,
|
||||
Host: mariadbModel.Host,
|
||||
Port: mariadbModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: mariadbModel.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
|
||||
ctx,
|
||||
logger,
|
||||
nil,
|
||||
uuid.New(),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Created user should be read-only")
|
||||
assert.Empty(t, privileges, "Read-only user should have no write privileges")
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username,
|
||||
password,
|
||||
container.Host,
|
||||
container.Port,
|
||||
container.Database,
|
||||
)
|
||||
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM readonly_test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec("INSERT INTO readonly_test (data) VALUES ('should-fail')")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("UPDATE readonly_test SET data = 'hacked' WHERE id = 1")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("DELETE FROM readonly_test WHERE id = 1")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
dropUserSafe(container.DB, username)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ReadOnlyUser_FutureTables_NoSelectPermission(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
|
||||
defer container.DB.Close()
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`DROP TABLE IF EXISTS future_table`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`
|
||||
CREATE TABLE future_table (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`INSERT INTO future_table (data) VALUES ('future_data')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username, password, container.Host, container.Port, container.Database)
|
||||
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var data string
|
||||
err = readOnlyConn.Get(&data, "SELECT data FROM future_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "future_data", data)
|
||||
|
||||
dropUserSafe(container.DB, username)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
|
||||
defer container.DB.Close()
|
||||
|
||||
dashDbName := "test-db-with-dash"
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dashDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dashDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dashDbName))
|
||||
}()
|
||||
|
||||
dashDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username, container.Password, container.Host, container.Port, dashDbName)
|
||||
dashDB, err := sqlx.Connect("mysql", dashDSN)
|
||||
assert.NoError(t, err)
|
||||
defer dashDB.Close()
|
||||
|
||||
_, err = dashDB.Exec(`
|
||||
CREATE TABLE dash_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = dashDB.Exec(`INSERT INTO dash_test (data) VALUES ('test1'), ('test2')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mariadbModel := &MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: container.Username,
|
||||
Password: container.Password,
|
||||
Database: &dashDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "pgs-"))
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username, password, container.Host, container.Port, dashDbName)
|
||||
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM dash_test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec("INSERT INTO dash_test (data) VALUES ('should-fail')")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
dropUserSafe(dashDB, username)
|
||||
}
|
||||
|
||||
func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS drop_test`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`
|
||||
CREATE TABLE drop_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`INSERT INTO drop_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username, password, container.Host, container.Port, container.Database)
|
||||
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
_, err = readOnlyConn.Exec("DROP TABLE drop_test")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("ALTER TABLE drop_test ADD COLUMN new_col VARCHAR(100)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("TRUNCATE TABLE drop_test")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
dropUserSafe(container.DB, username)
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseSpecificPrivilegesWithGlobalProcess_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMariadbContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS privilege_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE privilege_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO privilege_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
specificUsername := fmt.Sprintf("spec_%s", uuid.New().String()[:8])
|
||||
specificPassword := "specificpass123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
specificUsername,
|
||||
specificPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW ON %s.* TO '%s'@'%%'",
|
||||
container.Database,
|
||||
specificUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT PROCESS ON *.* TO '%s'@'%%'",
|
||||
specificUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(container.DB, specificUsername)
|
||||
|
||||
mariadbModel := &MariadbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: specificUsername,
|
||||
Password: specificPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mariadbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
|
||||
defer container.DB.Close()
|
||||
|
||||
underscoreDbName := "test_db_name"
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
|
||||
}()
|
||||
|
||||
underscoreDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username, container.Password, container.Host, container.Port, underscoreDbName)
|
||||
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
|
||||
assert.NoError(t, err)
|
||||
defer underscoreDB.Close()
|
||||
|
||||
_, err = underscoreDB.Exec(`
|
||||
CREATE TABLE underscore_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(`INSERT INTO underscore_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
underscoreUsername := fmt.Sprintf("under%s", uuid.New().String()[:8])
|
||||
underscorePassword := "underscorepass123"
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
underscoreUsername,
|
||||
underscorePassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW ON `%s`.* TO '%s'@'%%'",
|
||||
underscoreDbName,
|
||||
underscoreUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(underscoreDB, underscoreUsername)
|
||||
|
||||
mariadbModel := &MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: underscoreUsername,
|
||||
Password: underscorePassword,
|
||||
Database: &underscoreDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mariadbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
type MariadbContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
Version tools.MariadbVersion
|
||||
DB *sqlx.DB
|
||||
}
|
||||
|
||||
func connectToMariadbContainer(
|
||||
t *testing.T,
|
||||
port string,
|
||||
version tools.MariadbVersion,
|
||||
) *MariadbContainer {
|
||||
if port == "" {
|
||||
t.Skipf("MariaDB port not configured for version %s", version)
|
||||
}
|
||||
|
||||
dbName := "testdb"
|
||||
host := "127.0.0.1"
|
||||
username := "root"
|
||||
password := "rootpassword"
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
assert.NoError(t, err)
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username, password, host, portInt, dbName)
|
||||
|
||||
db, err := sqlx.Connect("mysql", dsn)
|
||||
if err != nil {
|
||||
t.Skipf("Failed to connect to MariaDB %s: %v", version, err)
|
||||
}
|
||||
|
||||
return &MariadbContainer{
|
||||
Host: host,
|
||||
Port: portInt,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: dbName,
|
||||
Version: version,
|
||||
DB: db,
|
||||
}
|
||||
}
|
||||
|
||||
func createMariadbModel(container *MariadbContainer) *MariadbDatabase {
|
||||
return &MariadbDatabase{
|
||||
Version: container.Version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: container.Username,
|
||||
Password: container.Password,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
}
|
||||
|
||||
func dropUserSafe(db *sqlx.DB, username string) {
|
||||
_, _ = db.Exec(fmt.Sprintf("DROP USER '%s'@'%%'", username))
|
||||
}
|
||||
@@ -1,387 +0,0 @@
|
||||
package mariadb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMariadbContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
isReadOnly, err := mariadbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, isReadOnly, "Root user should not be read-only")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMariadbContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS readonly_test`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`DROP TABLE IF EXISTS hack_table`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`DROP TABLE IF EXISTS future_table`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`
|
||||
CREATE TABLE readonly_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(
|
||||
`INSERT INTO readonly_test (data) VALUES ('test1'), ('test2')`,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "pgs-"))
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
readOnlyModel := &MariadbDatabase{
|
||||
Version: mariadbModel.Version,
|
||||
Host: mariadbModel.Host,
|
||||
Port: mariadbModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: mariadbModel.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Created user should be read-only")
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username,
|
||||
password,
|
||||
container.Host,
|
||||
container.Port,
|
||||
container.Database,
|
||||
)
|
||||
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM readonly_test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec("INSERT INTO readonly_test (data) VALUES ('should-fail')")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("UPDATE readonly_test SET data = 'hacked' WHERE id = 1")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("DELETE FROM readonly_test WHERE id = 1")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
dropUserSafe(container.DB, username)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ReadOnlyUser_FutureTables_NoSelectPermission(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
|
||||
defer container.DB.Close()
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`DROP TABLE IF EXISTS future_table`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`
|
||||
CREATE TABLE future_table (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`INSERT INTO future_table (data) VALUES ('future_data')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username, password, container.Host, container.Port, container.Database)
|
||||
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var data string
|
||||
err = readOnlyConn.Get(&data, "SELECT data FROM future_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "future_data", data)
|
||||
|
||||
dropUserSafe(container.DB, username)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
|
||||
defer container.DB.Close()
|
||||
|
||||
dashDbName := "test-db-with-dash"
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dashDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dashDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dashDbName))
|
||||
}()
|
||||
|
||||
dashDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username, container.Password, container.Host, container.Port, dashDbName)
|
||||
dashDB, err := sqlx.Connect("mysql", dashDSN)
|
||||
assert.NoError(t, err)
|
||||
defer dashDB.Close()
|
||||
|
||||
_, err = dashDB.Exec(`
|
||||
CREATE TABLE dash_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = dashDB.Exec(`INSERT INTO dash_test (data) VALUES ('test1'), ('test2')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mariadbModel := &MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: container.Username,
|
||||
Password: container.Password,
|
||||
Database: &dashDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "pgs-"))
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username, password, container.Host, container.Port, dashDbName)
|
||||
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM dash_test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec("INSERT INTO dash_test (data) VALUES ('should-fail')")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
dropUserSafe(dashDB, username)
|
||||
}
|
||||
|
||||
func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS drop_test`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`
|
||||
CREATE TABLE drop_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`INSERT INTO drop_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username, password, container.Host, container.Port, container.Database)
|
||||
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
_, err = readOnlyConn.Exec("DROP TABLE drop_test")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("ALTER TABLE drop_test ADD COLUMN new_col VARCHAR(100)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("TRUNCATE TABLE drop_test")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
dropUserSafe(container.DB, username)
|
||||
}
|
||||
|
||||
type MariadbContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
Version tools.MariadbVersion
|
||||
DB *sqlx.DB
|
||||
}
|
||||
|
||||
func connectToMariadbContainer(
|
||||
t *testing.T,
|
||||
port string,
|
||||
version tools.MariadbVersion,
|
||||
) *MariadbContainer {
|
||||
if port == "" {
|
||||
t.Skipf("MariaDB port not configured for version %s", version)
|
||||
}
|
||||
|
||||
dbName := "testdb"
|
||||
host := "127.0.0.1"
|
||||
username := "root"
|
||||
password := "rootpassword"
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
assert.NoError(t, err)
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username, password, host, portInt, dbName)
|
||||
|
||||
db, err := sqlx.Connect("mysql", dsn)
|
||||
if err != nil {
|
||||
t.Skipf("Failed to connect to MariaDB %s: %v", version, err)
|
||||
}
|
||||
|
||||
return &MariadbContainer{
|
||||
Host: host,
|
||||
Port: portInt,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: dbName,
|
||||
Version: version,
|
||||
DB: db,
|
||||
}
|
||||
}
|
||||
|
||||
func createMariadbModel(container *MariadbContainer) *MariadbDatabase {
|
||||
return &MariadbDatabase{
|
||||
Version: container.Version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: container.Username,
|
||||
Password: container.Password,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
}
|
||||
|
||||
func dropUserSafe(db *sqlx.DB, username string) {
|
||||
// MariaDB 5.5 doesn't support DROP USER IF EXISTS, so we ignore errors
|
||||
_, _ = db.Exec(fmt.Sprintf("DROP USER '%s'@'%%'", username))
|
||||
}
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"databasus-backend/internal/util/encryption"
|
||||
@@ -95,6 +97,16 @@ func (m *MongodbDatabase) TestConnection(
|
||||
}
|
||||
m.Version = detectedVersion
|
||||
|
||||
if err := checkBackupPermissions(
|
||||
ctx,
|
||||
client,
|
||||
m.Username,
|
||||
m.Database,
|
||||
m.AuthDatabase,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -134,14 +146,11 @@ func (m *MongodbDatabase) EncryptSensitiveFields(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MongodbDatabase) PopulateVersionIfEmpty(
|
||||
func (m *MongodbDatabase) PopulateDbData(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
if m.Version != "" {
|
||||
return nil
|
||||
}
|
||||
return m.PopulateVersion(logger, encryptor, databaseID)
|
||||
}
|
||||
|
||||
@@ -185,10 +194,10 @@ func (m *MongodbDatabase) IsUserReadOnly(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (bool, error) {
|
||||
) (bool, []string, error) {
|
||||
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
uri := m.buildConnectionURI(password)
|
||||
@@ -196,7 +205,7 @@ func (m *MongodbDatabase) IsUserReadOnly(
|
||||
clientOptions := options.Client().ApplyURI(uri)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to connect to database: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if disconnectErr := client.Disconnect(ctx); disconnectErr != nil {
|
||||
@@ -218,44 +227,153 @@ func (m *MongodbDatabase) IsUserReadOnly(
|
||||
}},
|
||||
}).Decode(&result)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get user info: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
|
||||
writeRoles := []string{
|
||||
"readWrite", "readWriteAnyDatabase", "dbAdmin", "dbAdminAnyDatabase",
|
||||
"userAdmin", "userAdminAnyDatabase", "clusterAdmin", "root",
|
||||
"dbOwner", "backup", "restore",
|
||||
writeRoles := map[string]bool{
|
||||
"readWrite": true,
|
||||
"readWriteAnyDatabase": true,
|
||||
"dbAdmin": true,
|
||||
"dbAdminAnyDatabase": true,
|
||||
"userAdmin": true,
|
||||
"userAdminAnyDatabase": true,
|
||||
"clusterAdmin": true,
|
||||
"clusterManager": true,
|
||||
"hostManager": true,
|
||||
"root": true,
|
||||
"dbOwner": true,
|
||||
"restore": true,
|
||||
"__system": true,
|
||||
}
|
||||
|
||||
// Roles that are read-only for our backup purposes
|
||||
// The "backup" role has insert/update on mms.backup collection but is needed for mongodump
|
||||
readOnlyRoles := map[string]bool{
|
||||
"read": true,
|
||||
"backup": true,
|
||||
}
|
||||
|
||||
writeActions := map[string]bool{
|
||||
"insert": true,
|
||||
"update": true,
|
||||
"remove": true,
|
||||
"createCollection": true,
|
||||
"dropCollection": true,
|
||||
"createIndex": true,
|
||||
"dropIndex": true,
|
||||
"convertToCapped": true,
|
||||
"dropDatabase": true,
|
||||
"renameCollection": true,
|
||||
"createUser": true,
|
||||
"dropUser": true,
|
||||
"updateUser": true,
|
||||
"grantRole": true,
|
||||
"revokeRole": true,
|
||||
"dropRole": true,
|
||||
"createRole": true,
|
||||
"updateRole": true,
|
||||
"enableSharding": true,
|
||||
"shardCollection": true,
|
||||
"addShard": true,
|
||||
"removeShard": true,
|
||||
"shutdown": true,
|
||||
"replSetReconfig": true,
|
||||
"replSetStateChange": true,
|
||||
}
|
||||
|
||||
var detectedRoles []string
|
||||
|
||||
users, ok := result["users"].(bson.A)
|
||||
if !ok || len(users) == 0 {
|
||||
return true, nil
|
||||
return true, detectedRoles, nil
|
||||
}
|
||||
|
||||
user, ok := users[0].(bson.M)
|
||||
if !ok {
|
||||
return true, nil
|
||||
return true, detectedRoles, nil
|
||||
}
|
||||
|
||||
roles, ok := user["roles"].(bson.A)
|
||||
if !ok {
|
||||
return true, nil
|
||||
return true, detectedRoles, nil
|
||||
}
|
||||
|
||||
// Collect all role names and check for write roles
|
||||
for _, roleDoc := range roles {
|
||||
role, ok := roleDoc.(bson.M)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
roleName, _ := role["role"].(string)
|
||||
for _, writeRole := range writeRoles {
|
||||
if roleName == writeRole {
|
||||
return false, nil
|
||||
if roleName != "" {
|
||||
detectedRoles = append(detectedRoles, roleName)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if any detected role is a write role
|
||||
for _, roleName := range detectedRoles {
|
||||
if writeRoles[roleName] {
|
||||
return false, detectedRoles, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If all roles are known read-only roles (read, backup), skip inherited privilege check
|
||||
allRolesReadOnly := true
|
||||
for _, roleName := range detectedRoles {
|
||||
if !readOnlyRoles[roleName] {
|
||||
allRolesReadOnly = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allRolesReadOnly && len(detectedRoles) > 0 {
|
||||
return true, detectedRoles, nil
|
||||
}
|
||||
|
||||
// Check inherited privileges for custom roles
|
||||
var privResult bson.M
|
||||
err = adminDB.RunCommand(ctx, bson.D{
|
||||
{Key: "usersInfo", Value: bson.D{
|
||||
{Key: "user", Value: m.Username},
|
||||
{Key: "db", Value: authDB},
|
||||
}},
|
||||
{Key: "showPrivileges", Value: true},
|
||||
}).Decode(&privResult)
|
||||
if err != nil {
|
||||
return false, nil, fmt.Errorf("failed to get user privileges: %w", err)
|
||||
}
|
||||
|
||||
privUsers, ok := privResult["users"].(bson.A)
|
||||
if !ok || len(privUsers) == 0 {
|
||||
return true, detectedRoles, nil
|
||||
}
|
||||
|
||||
privUser, ok := privUsers[0].(bson.M)
|
||||
if !ok {
|
||||
return true, detectedRoles, nil
|
||||
}
|
||||
|
||||
// Check inheritedPrivileges for write actions
|
||||
inheritedPrivileges, ok := privUser["inheritedPrivileges"].(bson.A)
|
||||
if ok {
|
||||
for _, privDoc := range inheritedPrivileges {
|
||||
priv, ok := privDoc.(bson.M)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
actions, ok := priv["actions"].(bson.A)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, action := range actions {
|
||||
actionStr, ok := action.(string)
|
||||
if ok && writeActions[actionStr] {
|
||||
return false, detectedRoles, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
return true, detectedRoles, nil
|
||||
}
|
||||
|
||||
func (m *MongodbDatabase) CreateReadOnlyUser(
|
||||
@@ -290,7 +408,7 @@ func (m *MongodbDatabase) CreateReadOnlyUser(
|
||||
maxRetries := 3
|
||||
for attempt := range maxRetries {
|
||||
newUsername := fmt.Sprintf("databasus-%s", uuid.New().String()[:8])
|
||||
newPassword := uuid.New().String()
|
||||
newPassword := encryption.GenerateComplexPassword()
|
||||
|
||||
adminDB := client.Database(authDB)
|
||||
err = adminDB.RunCommand(ctx, bson.D{
|
||||
@@ -332,20 +450,20 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string {
|
||||
authDB = "admin"
|
||||
}
|
||||
|
||||
tlsOption := "false"
|
||||
tlsParams := ""
|
||||
if m.IsHttps {
|
||||
tlsOption = "true"
|
||||
tlsParams = "&tls=true&tlsInsecure=true"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"mongodb://%s:%s@%s:%d/%s?authSource=%s&tls=%s&connectTimeoutMS=15000",
|
||||
m.Username,
|
||||
password,
|
||||
"mongodb://%s:%s@%s:%d/%s?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
m.Port,
|
||||
m.Database,
|
||||
authDB,
|
||||
tlsOption,
|
||||
tlsParams,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -356,19 +474,19 @@ func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
|
||||
authDB = "admin"
|
||||
}
|
||||
|
||||
tlsOption := "false"
|
||||
tlsParams := ""
|
||||
if m.IsHttps {
|
||||
tlsOption = "true"
|
||||
tlsParams = "&tls=true&tlsInsecure=true"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"mongodb://%s:%s@%s:%d/?authSource=%s&tls=%s&connectTimeoutMS=15000",
|
||||
m.Username,
|
||||
password,
|
||||
"mongodb://%s:%s@%s:%d/?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
m.Port,
|
||||
authDB,
|
||||
tlsOption,
|
||||
tlsParams,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -413,6 +531,128 @@ func detectMongodbVersion(ctx context.Context, client *mongo.Client) (tools.Mong
|
||||
}
|
||||
}
|
||||
|
||||
// checkBackupPermissions verifies the user has sufficient privileges for mongodump backup.
|
||||
// Required: 'read' role on target database OR 'backup' role on admin OR 'readAnyDatabase' role.
|
||||
func checkBackupPermissions(
|
||||
ctx context.Context,
|
||||
client *mongo.Client,
|
||||
username, database, authDatabase string,
|
||||
) error {
|
||||
authDB := authDatabase
|
||||
if authDB == "" {
|
||||
authDB = "admin"
|
||||
}
|
||||
|
||||
adminDB := client.Database(authDB)
|
||||
var result bson.M
|
||||
err := adminDB.RunCommand(ctx, bson.D{
|
||||
{Key: "usersInfo", Value: bson.D{
|
||||
{Key: "user", Value: username},
|
||||
{Key: "db", Value: authDB},
|
||||
}},
|
||||
{Key: "showPrivileges", Value: true},
|
||||
}).Decode(&result)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
|
||||
users, ok := result["users"].(bson.A)
|
||||
if !ok || len(users) == 0 {
|
||||
return errors.New("insufficient permissions for backup. User not found")
|
||||
}
|
||||
|
||||
user, ok := users[0].(bson.M)
|
||||
if !ok {
|
||||
return errors.New("insufficient permissions for backup. Could not parse user info")
|
||||
}
|
||||
|
||||
// Check roles for backup permissions
|
||||
roles, ok := user["roles"].(bson.A)
|
||||
if !ok {
|
||||
return errors.New("insufficient permissions for backup. No roles assigned")
|
||||
}
|
||||
|
||||
backupRoles := map[string]bool{
|
||||
"backup": true,
|
||||
"root": true,
|
||||
"readAnyDatabase": true,
|
||||
"dbOwner": true,
|
||||
"__system": true,
|
||||
"clusterAdmin": true,
|
||||
"readWriteAnyDatabase": true,
|
||||
}
|
||||
|
||||
var userRoles []string
|
||||
hasBackupRole := false
|
||||
hasReadOnTargetDB := false
|
||||
|
||||
for _, roleDoc := range roles {
|
||||
role, ok := roleDoc.(bson.M)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
roleName, _ := role["role"].(string)
|
||||
roleDB, _ := role["db"].(string)
|
||||
|
||||
if roleName != "" {
|
||||
userRoles = append(userRoles, roleName)
|
||||
}
|
||||
|
||||
if backupRoles[roleName] {
|
||||
hasBackupRole = true
|
||||
}
|
||||
|
||||
if roleName == "read" && (roleDB == database || roleDB == "") {
|
||||
hasReadOnTargetDB = true
|
||||
}
|
||||
if roleName == "readWrite" && (roleDB == database || roleDB == "") {
|
||||
hasReadOnTargetDB = true
|
||||
}
|
||||
}
|
||||
|
||||
if hasBackupRole || hasReadOnTargetDB {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check inherited privileges for 'find' action on target database
|
||||
inheritedPrivileges, ok := user["inheritedPrivileges"].(bson.A)
|
||||
if ok {
|
||||
for _, privDoc := range inheritedPrivileges {
|
||||
priv, ok := privDoc.(bson.M)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
resource, ok := priv["resource"].(bson.M)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
resourceDB, _ := resource["db"].(string)
|
||||
resourceCluster, _ := resource["cluster"].(bool)
|
||||
|
||||
isTargetDB := resourceDB == database || resourceDB == "" || resourceCluster
|
||||
|
||||
actions, ok := priv["actions"].(bson.A)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, action := range actions {
|
||||
actionStr, ok := action.(string)
|
||||
if ok && actionStr == "find" && isTargetDB {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf(
|
||||
"insufficient permissions for backup. Current roles: %s. Required: 'read' role on database '%s' OR 'backup' role on admin OR 'readAnyDatabase' role",
|
||||
strings.Join(userRoles, ", "),
|
||||
database,
|
||||
)
|
||||
}
|
||||
|
||||
func decryptPasswordIfNeeded(
|
||||
password string,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -19,6 +20,138 @@ import (
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MongodbVersion
|
||||
port string
|
||||
}{
|
||||
{"MongoDB 4.0", tools.MongodbVersion4, env.TestMongodb40Port},
|
||||
{"MongoDB 4.2", tools.MongodbVersion4, env.TestMongodb42Port},
|
||||
{"MongoDB 4.4", tools.MongodbVersion4, env.TestMongodb44Port},
|
||||
{"MongoDB 5.0", tools.MongodbVersion5, env.TestMongodb50Port},
|
||||
{"MongoDB 6.0", tools.MongodbVersion6, env.TestMongodb60Port},
|
||||
{"MongoDB 7.0", tools.MongodbVersion7, env.TestMongodb70Port},
|
||||
{"MongoDB 8.2", tools.MongodbVersion8, env.TestMongodb82Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMongodbContainer(t, tc.port, tc.version)
|
||||
defer container.Client.Disconnect(context.Background())
|
||||
|
||||
ctx := context.Background()
|
||||
db := container.Client.Database(container.Database)
|
||||
|
||||
_ = db.Collection("permission_test").Drop(ctx)
|
||||
_, err := db.Collection("permission_test").InsertOne(ctx, bson.M{"data": "test1"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
limitedUsername := fmt.Sprintf("limited_%s", uuid.New().String()[:8])
|
||||
limitedPassword := "limitedpassword123"
|
||||
|
||||
adminDB := container.Client.Database(container.AuthDatabase)
|
||||
err = adminDB.RunCommand(ctx, bson.D{
|
||||
{Key: "createUser", Value: limitedUsername},
|
||||
{Key: "pwd", Value: limitedPassword},
|
||||
{Key: "roles", Value: bson.A{}},
|
||||
}).Err()
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(container.Client, limitedUsername, container.AuthDatabase)
|
||||
|
||||
mongodbModel := &MongodbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: limitedUsername,
|
||||
Password: limitedPassword,
|
||||
Database: container.Database,
|
||||
AuthDatabase: container.AuthDatabase,
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mongodbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "insufficient permissions")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MongodbVersion
|
||||
port string
|
||||
}{
|
||||
{"MongoDB 4.0", tools.MongodbVersion4, env.TestMongodb40Port},
|
||||
{"MongoDB 4.2", tools.MongodbVersion4, env.TestMongodb42Port},
|
||||
{"MongoDB 4.4", tools.MongodbVersion4, env.TestMongodb44Port},
|
||||
{"MongoDB 5.0", tools.MongodbVersion5, env.TestMongodb50Port},
|
||||
{"MongoDB 6.0", tools.MongodbVersion6, env.TestMongodb60Port},
|
||||
{"MongoDB 7.0", tools.MongodbVersion7, env.TestMongodb70Port},
|
||||
{"MongoDB 8.2", tools.MongodbVersion8, env.TestMongodb82Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMongodbContainer(t, tc.port, tc.version)
|
||||
defer container.Client.Disconnect(context.Background())
|
||||
|
||||
ctx := context.Background()
|
||||
db := container.Client.Database(container.Database)
|
||||
|
||||
_ = db.Collection("backup_test").Drop(ctx)
|
||||
_, err := db.Collection("backup_test").InsertOne(ctx, bson.M{"data": "test1"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupUsername := fmt.Sprintf("backup_%s", uuid.New().String()[:8])
|
||||
backupPassword := "backuppassword123"
|
||||
|
||||
adminDB := container.Client.Database(container.AuthDatabase)
|
||||
err = adminDB.RunCommand(ctx, bson.D{
|
||||
{Key: "createUser", Value: backupUsername},
|
||||
{Key: "pwd", Value: backupPassword},
|
||||
{Key: "roles", Value: bson.A{
|
||||
bson.D{
|
||||
{Key: "role", Value: "read"},
|
||||
{Key: "db", Value: container.Database},
|
||||
},
|
||||
}},
|
||||
}).Err()
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(container.Client, backupUsername, container.AuthDatabase)
|
||||
|
||||
mongodbModel := &MongodbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: backupUsername,
|
||||
Password: backupPassword,
|
||||
Database: container.Database,
|
||||
AuthDatabase: container.AuthDatabase,
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mongodbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
@@ -46,13 +179,52 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
isReadOnly, err := mongodbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
isReadOnly, roles, err := mongodbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, isReadOnly, "Root user should not be read-only")
|
||||
assert.NotEmpty(t, roles, "Root user should have roles")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMongodbContainer(t, env.TestMongodb70Port, tools.MongodbVersion7)
|
||||
defer container.Client.Disconnect(context.Background())
|
||||
|
||||
ctx := context.Background()
|
||||
db := container.Client.Database(container.Database)
|
||||
|
||||
_ = db.Collection("readonly_check_test").Drop(ctx)
|
||||
_, err := db.Collection("readonly_check_test").InsertOne(ctx, bson.M{"data": "test1"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
mongodbModel := createMongodbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
username, password, err := mongodbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyModel := &MongodbDatabase{
|
||||
Version: mongodbModel.Version,
|
||||
Host: mongodbModel.Host,
|
||||
Port: mongodbModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: mongodbModel.Database,
|
||||
AuthDatabase: mongodbModel.AuthDatabase,
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
isReadOnly, roles, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Read-only user should be read-only")
|
||||
assert.NotEmpty(t, roles, "Read-only user should have roles (read, backup)")
|
||||
|
||||
dropUserSafe(container.Client, username, container.AuthDatabase)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
@@ -271,6 +443,7 @@ func createMongodbModel(container *MongodbContainer) *MongodbDatabase {
|
||||
Database: container.Database,
|
||||
AuthDatabase: container.AuthDatabase,
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -281,7 +454,8 @@ func connectWithCredentials(
|
||||
) *mongo.Client {
|
||||
uri := fmt.Sprintf(
|
||||
"mongodb://%s:%s@%s:%d/%s?authSource=%s",
|
||||
username, password, container.Host, container.Port,
|
||||
url.QueryEscape(username), url.QueryEscape(password),
|
||||
container.Host, container.Port,
|
||||
container.Database, container.AuthDatabase,
|
||||
)
|
||||
|
||||
@@ -2,17 +2,20 @@ package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/tools"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
@@ -22,12 +25,13 @@ type MysqlDatabase struct {
|
||||
|
||||
Version tools.MysqlVersion `json:"version" gorm:"type:text;not null"`
|
||||
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database *string `json:"database" gorm:"type:text"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database *string `json:"database" gorm:"type:text"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) TableName() string {
|
||||
@@ -93,6 +97,16 @@ func (m *MysqlDatabase) TestConnection(
|
||||
}
|
||||
m.Version = detectedVersion
|
||||
|
||||
privileges, err := detectPrivileges(ctx, db, *m.Database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Privileges = privileges
|
||||
|
||||
if err := checkBackupPermissions(m.Privileges); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -110,6 +124,7 @@ func (m *MysqlDatabase) Update(incoming *MysqlDatabase) {
|
||||
m.Username = incoming.Username
|
||||
m.Database = incoming.Database
|
||||
m.IsHttps = incoming.IsHttps
|
||||
m.Privileges = incoming.Privileges
|
||||
|
||||
if incoming.Password != "" {
|
||||
m.Password = incoming.Password
|
||||
@@ -130,15 +145,48 @@ func (m *MysqlDatabase) EncryptSensitiveFields(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) PopulateVersionIfEmpty(
|
||||
func (m *MysqlDatabase) PopulateDbData(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
if m.Version != "" {
|
||||
if m.Database == nil || *m.Database == "" {
|
||||
return nil
|
||||
}
|
||||
return m.PopulateVersion(logger, encryptor, databaseID)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
dsn := m.buildDSN(password, *m.Database)
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := db.Close(); closeErr != nil {
|
||||
logger.Error("Failed to close connection", "error", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
detectedVersion, err := detectMysqlVersion(ctx, db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Version = detectedVersion
|
||||
|
||||
privileges, err := detectPrivileges(ctx, db, *m.Database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Privileges = privileges
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) PopulateVersion(
|
||||
@@ -174,8 +222,8 @@ func (m *MysqlDatabase) PopulateVersion(
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Version = detectedVersion
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -184,17 +232,17 @@ func (m *MysqlDatabase) IsUserReadOnly(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (bool, error) {
|
||||
) (bool, []string, error) {
|
||||
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
dsn := m.buildDSN(password, *m.Database)
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to connect to database: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := db.Close(); closeErr != nil {
|
||||
@@ -204,33 +252,45 @@ func (m *MysqlDatabase) IsUserReadOnly(
|
||||
|
||||
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check grants: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to check grants: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
writePrivileges := []string{
|
||||
"INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER",
|
||||
"INDEX", "GRANT OPTION", "ALL PRIVILEGES", "SUPER",
|
||||
"EXECUTE", "FILE", "RELOAD", "SHUTDOWN", "CREATE ROUTINE",
|
||||
"ALTER ROUTINE", "CREATE USER",
|
||||
"CREATE TABLESPACE", "REFERENCES",
|
||||
}
|
||||
|
||||
detectedPrivileges := make(map[string]bool)
|
||||
|
||||
for rows.Next() {
|
||||
var grant string
|
||||
if err := rows.Scan(&grant); err != nil {
|
||||
return false, fmt.Errorf("failed to scan grant: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to scan grant: %w", err)
|
||||
}
|
||||
|
||||
for _, priv := range writePrivileges {
|
||||
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
|
||||
return false, nil
|
||||
detectedPrivileges[priv] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return false, fmt.Errorf("error iterating grants: %w", err)
|
||||
return false, nil, fmt.Errorf("error iterating grants: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
privileges := make([]string, 0, len(detectedPrivileges))
|
||||
for priv := range detectedPrivileges {
|
||||
privileges = append(privileges, priv)
|
||||
}
|
||||
|
||||
isReadOnly := len(privileges) == 0
|
||||
|
||||
return isReadOnly, privileges, nil
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) CreateReadOnlyUser(
|
||||
@@ -259,7 +319,7 @@ func (m *MysqlDatabase) CreateReadOnlyUser(
|
||||
maxRetries := 3
|
||||
for attempt := range maxRetries {
|
||||
newUsername := fmt.Sprintf("databasus-%s", uuid.New().String()[:8])
|
||||
newPassword := uuid.New().String()
|
||||
newPassword := encryption.GenerateComplexPassword()
|
||||
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
@@ -325,20 +385,45 @@ func (m *MysqlDatabase) CreateReadOnlyUser(
|
||||
return "", "", errors.New("failed to generate unique username after 3 attempts")
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) HasPrivilege(priv string) bool {
|
||||
return HasPrivilege(m.Privileges, priv)
|
||||
}
|
||||
|
||||
func HasPrivilege(privileges, priv string) bool {
|
||||
for p := range strings.SplitSeq(privileges, ",") {
|
||||
if strings.TrimSpace(p) == priv {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) buildDSN(password string, database string) string {
|
||||
tlsConfig := "false"
|
||||
allowCleartext := ""
|
||||
|
||||
if m.IsHttps {
|
||||
tlsConfig = "true"
|
||||
err := mysql.RegisterTLSConfig("mysql-skip-verify", &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
// Config might already be registered, which is fine
|
||||
_ = err
|
||||
}
|
||||
|
||||
tlsConfig = "mysql-skip-verify"
|
||||
allowCleartext = "&allowCleartextPasswords=1"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4",
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4%s",
|
||||
m.Username,
|
||||
password,
|
||||
m.Host,
|
||||
m.Port,
|
||||
database,
|
||||
tlsConfig,
|
||||
allowCleartext,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -388,6 +473,104 @@ func mapMysql8xVersion(minor string) tools.MysqlVersion {
|
||||
}
|
||||
}
|
||||
|
||||
// detectPrivileges detects backup-related privileges and returns them as comma-separated string
|
||||
func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string, error) {
|
||||
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to check grants: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
backupPrivileges := []string{
|
||||
"SELECT", "SHOW VIEW", "LOCK TABLES", "TRIGGER", "EVENT",
|
||||
}
|
||||
|
||||
detectedPrivileges := make(map[string]bool)
|
||||
hasProcess := false
|
||||
hasAllPrivileges := false
|
||||
|
||||
dbPatternStr := fmt.Sprintf(
|
||||
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
|
||||
regexp.QuoteMeta(database),
|
||||
)
|
||||
dbPattern := regexp.MustCompile(dbPatternStr)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
|
||||
allPrivilegesPattern := regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`)
|
||||
|
||||
for rows.Next() {
|
||||
var grant string
|
||||
if err := rows.Scan(&grant); err != nil {
|
||||
return "", fmt.Errorf("failed to scan grant: %w", err)
|
||||
}
|
||||
|
||||
isRelevantGrant := globalPattern.MatchString(grant) || dbPattern.MatchString(grant)
|
||||
|
||||
if allPrivilegesPattern.MatchString(grant) && isRelevantGrant {
|
||||
hasAllPrivileges = true
|
||||
}
|
||||
|
||||
if isRelevantGrant {
|
||||
for _, priv := range backupPrivileges {
|
||||
privPattern := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(priv) + `\b`)
|
||||
if privPattern.MatchString(grant) {
|
||||
detectedPrivileges[priv] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if globalPattern.MatchString(grant) {
|
||||
processPattern := regexp.MustCompile(`(?i)\bPROCESS\b`)
|
||||
if processPattern.MatchString(grant) {
|
||||
hasProcess = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return "", fmt.Errorf("error iterating grants: %w", err)
|
||||
}
|
||||
|
||||
if hasAllPrivileges {
|
||||
for _, priv := range backupPrivileges {
|
||||
detectedPrivileges[priv] = true
|
||||
}
|
||||
hasProcess = true
|
||||
}
|
||||
|
||||
privileges := make([]string, 0, len(detectedPrivileges)+1)
|
||||
for priv := range detectedPrivileges {
|
||||
privileges = append(privileges, priv)
|
||||
}
|
||||
if hasProcess {
|
||||
privileges = append(privileges, "PROCESS")
|
||||
}
|
||||
|
||||
sort.Strings(privileges)
|
||||
return strings.Join(privileges, ","), nil
|
||||
}
|
||||
|
||||
// checkBackupPermissions verifies the user has sufficient privileges for mysqldump backup.
|
||||
// Required: SELECT, SHOW VIEW
|
||||
func checkBackupPermissions(privileges string) error {
|
||||
requiredPrivileges := []string{"SELECT", "SHOW VIEW"}
|
||||
|
||||
var missingPrivileges []string
|
||||
for _, priv := range requiredPrivileges {
|
||||
if !HasPrivilege(privileges, priv) {
|
||||
missingPrivileges = append(missingPrivileges, priv)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingPrivileges) > 0 {
|
||||
return fmt.Errorf(
|
||||
"insufficient permissions for backup. Missing: %s. Required: SELECT, SHOW VIEW",
|
||||
strings.Join(missingPrivileges, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decryptPasswordIfNeeded(
|
||||
password string,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
|
||||
@@ -18,6 +18,165 @@ import (
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MysqlVersion
|
||||
port string
|
||||
}{
|
||||
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
|
||||
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
|
||||
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
|
||||
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMysqlContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS permission_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE permission_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO permission_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
limitedUsername := fmt.Sprintf("limited_%s", uuid.New().String()[:8])
|
||||
limitedPassword := "limitedpassword123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
limitedUsername,
|
||||
limitedPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT ON `%s`.* TO '%s'@'%%'",
|
||||
container.Database,
|
||||
limitedUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", limitedUsername),
|
||||
)
|
||||
}()
|
||||
|
||||
mysqlModel := &MysqlDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: limitedUsername,
|
||||
Password: limitedPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mysqlModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "insufficient permissions")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MysqlVersion
|
||||
port string
|
||||
}{
|
||||
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
|
||||
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
|
||||
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
|
||||
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMysqlContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS backup_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE backup_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO backup_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupUsername := fmt.Sprintf("backup_%s", uuid.New().String()[:8])
|
||||
backupPassword := "backuppassword123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
backupUsername,
|
||||
backupPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT ON `%s`.* TO '%s'@'%%'",
|
||||
container.Database,
|
||||
backupUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT PROCESS ON *.* TO '%s'@'%%'",
|
||||
backupUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", backupUsername),
|
||||
)
|
||||
}()
|
||||
|
||||
mysqlModel := &MysqlDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: backupUsername,
|
||||
Password: backupPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mysqlModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
@@ -42,13 +201,57 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
isReadOnly, err := mysqlModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
isReadOnly, privileges, err := mysqlModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, isReadOnly, "Root user should not be read-only")
|
||||
assert.NotEmpty(t, privileges, "Root user should have privileges")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMysqlContainer(t, env.TestMysql80Port, tools.MysqlVersion80)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS readonly_check_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE readonly_check_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO readonly_check_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mysqlModel := createMysqlModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyModel := &MysqlDatabase{
|
||||
Version: mysqlModel.Version,
|
||||
Host: mysqlModel.Host,
|
||||
Port: mysqlModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: mysqlModel.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Read-only user should be read-only")
|
||||
assert.Empty(t, privileges, "Read-only user should have no write privileges")
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", username))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
@@ -109,9 +312,15 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
|
||||
ctx,
|
||||
logger,
|
||||
nil,
|
||||
uuid.New(),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Created user should be read-only")
|
||||
assert.Empty(t, privileges, "Read-only user should have no write privileges")
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
@@ -309,6 +518,162 @@ func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseSpecificPrivilegesWithGlobalProcess_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MysqlVersion
|
||||
port string
|
||||
}{
|
||||
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
|
||||
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
|
||||
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
|
||||
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMysqlContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS privilege_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE privilege_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO privilege_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
specificUsername := fmt.Sprintf("specific_%s", uuid.New().String()[:8])
|
||||
specificPassword := "specificpass123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
specificUsername,
|
||||
specificPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW ON %s.* TO '%s'@'%%'",
|
||||
container.Database,
|
||||
specificUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT PROCESS ON *.* TO '%s'@'%%'",
|
||||
specificUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", specificUsername),
|
||||
)
|
||||
}()
|
||||
|
||||
mysqlModel := &MysqlDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: specificUsername,
|
||||
Password: specificPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mysqlModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMysqlContainer(t, env.TestMysql80Port, tools.MysqlVersion80)
|
||||
defer container.DB.Close()
|
||||
|
||||
underscoreDbName := "test_db_name"
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
|
||||
}()
|
||||
|
||||
underscoreDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username, container.Password, container.Host, container.Port, underscoreDbName)
|
||||
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
|
||||
assert.NoError(t, err)
|
||||
defer underscoreDB.Close()
|
||||
|
||||
_, err = underscoreDB.Exec(`
|
||||
CREATE TABLE underscore_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(`INSERT INTO underscore_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
underscoreUsername := fmt.Sprintf("under_%s", uuid.New().String()[:8])
|
||||
underscorePassword := "underscorepass123"
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
underscoreUsername,
|
||||
underscorePassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW ON `%s`.* TO '%s'@'%%'",
|
||||
underscoreDbName,
|
||||
underscoreUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = underscoreDB.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", underscoreUsername))
|
||||
}()
|
||||
|
||||
mysqlModel := &MysqlDatabase{
|
||||
Version: tools.MysqlVersion80,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: underscoreUsername,
|
||||
Password: underscorePassword,
|
||||
Database: &underscoreDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mysqlModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
type MysqlContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
@@ -85,6 +85,27 @@ func (p *PostgresqlDatabase) Validate() error {
|
||||
return errors.New("cpu count must be greater than 0")
|
||||
}
|
||||
|
||||
// Prevent Databasus from backing up itself
|
||||
// Databasus runs an internal PostgreSQL instance that should not be backed up through the UI
|
||||
// because it would expose internal metadata to non-system administrators.
|
||||
// To properly backup Databasus, see: https://databasus.com/faq#backup-databasus
|
||||
if p.Database != nil && *p.Database != "" {
|
||||
localhostHosts := []string{"localhost", "127.0.0.1", "172.17.0.1", "host.docker.internal"}
|
||||
isLocalhost := false
|
||||
for _, host := range localhostHosts {
|
||||
if strings.EqualFold(p.Host, host) {
|
||||
isLocalhost = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isLocalhost && strings.EqualFold(*p.Database, "databasus") {
|
||||
return errors.New(
|
||||
"backing up Databasus internal database is not allowed. To backup Databasus itself, see https://databasus.com/faq#backup-databasus",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -137,16 +158,13 @@ func (p *PostgresqlDatabase) EncryptSensitiveFields(
|
||||
return nil
|
||||
}
|
||||
|
||||
// PopulateVersionIfEmpty detects and sets the PostgreSQL version if not already set.
|
||||
// PopulateDbData detects and sets the PostgreSQL version.
|
||||
// This should be called before encrypting sensitive fields.
|
||||
func (p *PostgresqlDatabase) PopulateVersionIfEmpty(
|
||||
func (p *PostgresqlDatabase) PopulateDbData(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
if p.Version != "" {
|
||||
return nil
|
||||
}
|
||||
return p.PopulateVersion(logger, encryptor, databaseID)
|
||||
}
|
||||
|
||||
@@ -192,29 +210,33 @@ func (p *PostgresqlDatabase) PopulateVersion(
|
||||
// IsUserReadOnly checks if the database user has read-only privileges.
|
||||
//
|
||||
// This method performs a comprehensive security check by examining:
|
||||
// - Role-level attributes (superuser, createrole, createdb)
|
||||
// - Role-level attributes (superuser, createrole, createdb, bypassrls, replication)
|
||||
// - Database-level privileges (CREATE, TEMP)
|
||||
// - Schema-level privileges (CREATE on any non-system schema)
|
||||
// - Table-level write permissions (INSERT, UPDATE, DELETE, TRUNCATE, REFERENCES, TRIGGER)
|
||||
// - Function-level privileges (EXECUTE on SECURITY DEFINER functions)
|
||||
//
|
||||
// A user is considered read-only only if they have ZERO write privileges
|
||||
// across all three levels. This ensures the database user follows the
|
||||
// across all levels. This ensures the database user follows the
|
||||
// principle of least privilege for backup operations.
|
||||
//
|
||||
// Returns: (isReadOnly, detectedPrivileges, error)
|
||||
func (p *PostgresqlDatabase) IsUserReadOnly(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (bool, error) {
|
||||
) (bool, []string, error) {
|
||||
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
connStr := buildConnectionStringForDB(p, *p.Database, password)
|
||||
|
||||
conn, err := pgx.Connect(ctx, connStr)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to connect to database: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := conn.Close(ctx); closeErr != nil {
|
||||
@@ -222,22 +244,38 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
|
||||
}
|
||||
}()
|
||||
|
||||
var privileges []string
|
||||
|
||||
// LEVEL 1: Check role-level attributes
|
||||
var isSuperuser, canCreateRole, canCreateDB bool
|
||||
var isSuperuser, canCreateRole, canCreateDB, canBypassRLS, canReplication bool
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT
|
||||
rolsuper,
|
||||
rolcreaterole,
|
||||
rolcreatedb
|
||||
rolcreatedb,
|
||||
rolbypassrls,
|
||||
rolreplication
|
||||
FROM pg_roles
|
||||
WHERE rolname = current_user
|
||||
`).Scan(&isSuperuser, &canCreateRole, &canCreateDB)
|
||||
`).Scan(&isSuperuser, &canCreateRole, &canCreateDB, &canBypassRLS, &canReplication)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check role attributes: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to check role attributes: %w", err)
|
||||
}
|
||||
|
||||
if isSuperuser || canCreateRole || canCreateDB {
|
||||
return false, nil
|
||||
if isSuperuser {
|
||||
privileges = append(privileges, "SUPERUSER")
|
||||
}
|
||||
if canCreateRole {
|
||||
privileges = append(privileges, "CREATEROLE")
|
||||
}
|
||||
if canCreateDB {
|
||||
privileges = append(privileges, "CREATEDB")
|
||||
}
|
||||
if canBypassRLS {
|
||||
privileges = append(privileges, "BYPASSRLS")
|
||||
}
|
||||
if canReplication {
|
||||
privileges = append(privileges, "REPLICATION")
|
||||
}
|
||||
|
||||
// LEVEL 2: Check database-level privileges
|
||||
@@ -248,46 +286,34 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
|
||||
has_database_privilege(current_user, current_database(), 'TEMP') as can_temp
|
||||
`).Scan(&canCreate, &canTemp)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check database privileges: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to check database privileges: %w", err)
|
||||
}
|
||||
|
||||
if canCreate || canTemp {
|
||||
return false, nil
|
||||
if canCreate {
|
||||
privileges = append(privileges, "CREATE (database)")
|
||||
}
|
||||
if canTemp {
|
||||
privileges = append(privileges, "TEMP")
|
||||
}
|
||||
|
||||
// LEVEL 2.5: Check schema-level CREATE privileges
|
||||
schemaRows, err := conn.Query(ctx, `
|
||||
SELECT DISTINCT nspname
|
||||
FROM pg_namespace n
|
||||
WHERE has_schema_privilege(current_user, n.nspname, 'CREATE')
|
||||
AND nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
`)
|
||||
var hasSchemaCreate bool
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT EXISTS(
|
||||
SELECT 1
|
||||
FROM pg_namespace n
|
||||
WHERE has_schema_privilege(current_user, n.nspname, 'CREATE')
|
||||
AND nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
)
|
||||
`).Scan(&hasSchemaCreate)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check schema privileges: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to check schema privileges: %w", err)
|
||||
}
|
||||
defer schemaRows.Close()
|
||||
|
||||
// If user has CREATE privilege on any schema, they're not read-only
|
||||
if schemaRows.Next() {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := schemaRows.Err(); err != nil {
|
||||
return false, fmt.Errorf("error iterating schema privileges: %w", err)
|
||||
if hasSchemaCreate {
|
||||
privileges = append(privileges, "CREATE (schema)")
|
||||
}
|
||||
|
||||
// LEVEL 3: Check table-level write permissions
|
||||
rows, err := conn.Query(ctx, `
|
||||
SELECT DISTINCT privilege_type
|
||||
FROM information_schema.role_table_grants
|
||||
WHERE grantee = current_user
|
||||
AND table_schema NOT IN ('pg_catalog', 'information_schema')
|
||||
`)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check table privileges: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
writePrivileges := map[string]bool{
|
||||
"INSERT": true,
|
||||
"UPDATE": true,
|
||||
@@ -297,22 +323,56 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
|
||||
"TRIGGER": true,
|
||||
}
|
||||
|
||||
var tablePrivileges []string
|
||||
rows, err := conn.Query(ctx, `
|
||||
SELECT DISTINCT privilege_type
|
||||
FROM information_schema.role_table_grants
|
||||
WHERE grantee = current_user
|
||||
AND table_schema NOT IN ('pg_catalog', 'information_schema')
|
||||
`)
|
||||
if err != nil {
|
||||
return false, nil, fmt.Errorf("failed to check table privileges: %w", err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var privilege string
|
||||
if err := rows.Scan(&privilege); err != nil {
|
||||
return false, fmt.Errorf("failed to scan privilege: %w", err)
|
||||
}
|
||||
|
||||
if writePrivileges[privilege] {
|
||||
return false, nil
|
||||
rows.Close()
|
||||
return false, nil, fmt.Errorf("failed to scan privilege: %w", err)
|
||||
}
|
||||
tablePrivileges = append(tablePrivileges, privilege)
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return false, fmt.Errorf("error iterating privileges: %w", err)
|
||||
return false, nil, fmt.Errorf("error iterating privileges: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
for _, privilege := range tablePrivileges {
|
||||
if writePrivileges[privilege] {
|
||||
privileges = append(privileges, privilege)
|
||||
}
|
||||
}
|
||||
|
||||
// LEVEL 4: Check for EXECUTE privilege on functions that are SECURITY DEFINER
|
||||
var funcCount int
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_proc p
|
||||
JOIN pg_namespace n ON p.pronamespace = n.oid
|
||||
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND p.prosecdef = true
|
||||
AND has_function_privilege(current_user, p.oid, 'EXECUTE')
|
||||
`).Scan(&funcCount)
|
||||
if err != nil {
|
||||
return false, nil, fmt.Errorf("failed to check function privileges: %w", err)
|
||||
}
|
||||
if funcCount > 0 {
|
||||
privileges = append(privileges, "EXECUTE (SECURITY DEFINER)")
|
||||
}
|
||||
|
||||
isReadOnly := len(privileges) == 0
|
||||
return isReadOnly, privileges, nil
|
||||
}
|
||||
|
||||
// CreateReadOnlyUser creates a new PostgreSQL user with read-only privileges.
|
||||
@@ -383,7 +443,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
}
|
||||
}
|
||||
|
||||
newPassword := uuid.New().String()
|
||||
newPassword := encryption.GenerateComplexPassword()
|
||||
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
@@ -631,13 +691,9 @@ func testSingleDatabaseConnection(
|
||||
}
|
||||
postgresDb.Version = detectedVersion
|
||||
|
||||
// Test if we can perform basic operations (like pg_dump would need)
|
||||
if err := testBasicOperations(ctx, conn, *postgresDb.Database); err != nil {
|
||||
return fmt.Errorf(
|
||||
"basic operations test failed for database '%s': %w",
|
||||
*postgresDb.Database,
|
||||
err,
|
||||
)
|
||||
// Verify user has sufficient permissions for backup operations
|
||||
if err := checkBackupPermissions(ctx, conn, postgresDb.IncludeSchemas); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -670,18 +726,128 @@ func detectDatabaseVersion(ctx context.Context, conn *pgx.Conn) (tools.Postgresq
|
||||
}
|
||||
}
|
||||
|
||||
// testBasicOperations tests basic operations that backup tools need
|
||||
func testBasicOperations(ctx context.Context, conn *pgx.Conn, dbName string) error {
|
||||
var hasCreatePriv bool
|
||||
// checkBackupPermissions verifies the user has sufficient privileges for pg_dump backup.
|
||||
// Required privileges: CONNECT on database, USAGE on schemas, SELECT on tables.
|
||||
// If includeSchemas is specified, only checks permissions on those schemas.
|
||||
func checkBackupPermissions(
|
||||
ctx context.Context,
|
||||
conn *pgx.Conn,
|
||||
includeSchemas []string,
|
||||
) error {
|
||||
var missingPrivileges []string
|
||||
|
||||
// Check CONNECT privilege on database
|
||||
var hasConnect bool
|
||||
err := conn.QueryRow(ctx, "SELECT has_database_privilege(current_user, current_database(), 'CONNECT')").
|
||||
Scan(&hasCreatePriv)
|
||||
Scan(&hasConnect)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot check database privileges: %w", err)
|
||||
}
|
||||
if !hasConnect {
|
||||
missingPrivileges = append(missingPrivileges, "CONNECT on database")
|
||||
}
|
||||
|
||||
if !hasCreatePriv {
|
||||
return fmt.Errorf("user does not have CONNECT privilege on database '%s'", dbName)
|
||||
// Check USAGE privilege on at least one non-system schema
|
||||
var schemaCount int
|
||||
if len(includeSchemas) > 0 {
|
||||
// Check only the specified schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_namespace n
|
||||
WHERE has_schema_privilege(current_user, n.nspname, 'USAGE')
|
||||
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND n.nspname NOT LIKE 'pg_temp_%'
|
||||
AND n.nspname NOT LIKE 'pg_toast_temp_%'
|
||||
AND n.nspname = ANY($1::text[])
|
||||
`, includeSchemas).Scan(&schemaCount)
|
||||
} else {
|
||||
// Check all non-system schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_namespace n
|
||||
WHERE has_schema_privilege(current_user, n.nspname, 'USAGE')
|
||||
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND n.nspname NOT LIKE 'pg_temp_%'
|
||||
AND n.nspname NOT LIKE 'pg_toast_temp_%'
|
||||
`).Scan(&schemaCount)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot check schema privileges: %w", err)
|
||||
}
|
||||
if schemaCount == 0 {
|
||||
missingPrivileges = append(missingPrivileges, "USAGE on at least one schema")
|
||||
}
|
||||
|
||||
// Check SELECT privilege on at least one table (if tables exist)
|
||||
// Use pg_tables from pg_catalog which shows all tables regardless of user privileges
|
||||
var tableCount int
|
||||
|
||||
if len(includeSchemas) > 0 {
|
||||
// Check only tables in the specified schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND t.schemaname NOT LIKE 'pg_temp_%'
|
||||
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
|
||||
AND t.schemaname = ANY($1::text[])
|
||||
`, includeSchemas).Scan(&tableCount)
|
||||
} else {
|
||||
// Check all tables in non-system schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND t.schemaname NOT LIKE 'pg_temp_%'
|
||||
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
|
||||
`).Scan(&tableCount)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot check table count: %w", err)
|
||||
}
|
||||
|
||||
if tableCount > 0 {
|
||||
// Check if user has SELECT on at least one of these tables
|
||||
var selectableTableCount int
|
||||
|
||||
if len(includeSchemas) > 0 {
|
||||
// Check only tables in the specified schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND t.schemaname NOT LIKE 'pg_temp_%'
|
||||
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
|
||||
AND t.schemaname = ANY($1::text[])
|
||||
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
|
||||
`, includeSchemas).Scan(&selectableTableCount)
|
||||
} else {
|
||||
// Check all tables in non-system schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND t.schemaname NOT LIKE 'pg_temp_%'
|
||||
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
|
||||
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
|
||||
`).Scan(&selectableTableCount)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot check SELECT privileges: %w", err)
|
||||
}
|
||||
if selectableTableCount == 0 {
|
||||
missingPrivileges = append(missingPrivileges, "SELECT on tables")
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingPrivileges) > 0 {
|
||||
return fmt.Errorf(
|
||||
"insufficient permissions for backup. Missing: %s. Required: CONNECT on database, USAGE on schemas, SELECT on tables",
|
||||
strings.Join(missingPrivileges, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -695,16 +861,22 @@ func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string, password s
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s default_query_exec_mode=simple_protocol standard_conforming_strings=on client_encoding=UTF8",
|
||||
"host=%s port=%d user=%s password='%s' dbname=%s sslmode=%s default_query_exec_mode=simple_protocol standard_conforming_strings=on client_encoding=UTF8",
|
||||
p.Host,
|
||||
p.Port,
|
||||
p.Username,
|
||||
password,
|
||||
escapeConnectionStringValue(password),
|
||||
dbName,
|
||||
sslMode,
|
||||
)
|
||||
}
|
||||
|
||||
func escapeConnectionStringValue(value string) string {
|
||||
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||
value = strings.ReplaceAll(value, `'`, `\'`)
|
||||
return value
|
||||
}
|
||||
|
||||
func decryptPasswordIfNeeded(
|
||||
password string,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
|
||||
@@ -13,11 +13,236 @@ import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func Test_TestConnection_PasswordContainingSpaces_TestedSuccessfully(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
passwordWithSpaces := "test password with spaces"
|
||||
usernameWithSpaces := fmt.Sprintf("testuser_spaces_%s", uuid.New().String()[:8])
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
DROP TABLE IF EXISTS password_test CASCADE;
|
||||
CREATE TABLE password_test (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO password_test (data) VALUES ('test1');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
|
||||
usernameWithSpaces,
|
||||
passwordWithSpaces,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
|
||||
container.Database,
|
||||
usernameWithSpaces,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT USAGE ON SCHEMA public TO "%s"`,
|
||||
usernameWithSpaces,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT SELECT ON ALL TABLES IN SCHEMA public TO "%s"`,
|
||||
usernameWithSpaces,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, usernameWithSpaces))
|
||||
}()
|
||||
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Version: tools.GetPostgresqlVersionEnum("16"),
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: usernameWithSpaces,
|
||||
Password: passwordWithSpaces,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = pgModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
DROP TABLE IF EXISTS permission_test CASCADE;
|
||||
CREATE TABLE permission_test (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO permission_test (data) VALUES ('test1');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
limitedUsername := fmt.Sprintf("limited_user_%s", uuid.New().String()[:8])
|
||||
limitedPassword := "limitedpassword123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
|
||||
limitedUsername,
|
||||
limitedPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
|
||||
container.Database,
|
||||
limitedUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, limitedUsername))
|
||||
}()
|
||||
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Version: tools.GetPostgresqlVersionEnum(tc.version),
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: limitedUsername,
|
||||
Password: limitedPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = pgModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.Error(t, err)
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
DROP TABLE IF EXISTS backup_test CASCADE;
|
||||
CREATE TABLE backup_test (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO backup_test (data) VALUES ('test1');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupUsername := fmt.Sprintf("backup_user_%s", uuid.New().String()[:8])
|
||||
backupPassword := "backuppassword123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
|
||||
backupUsername,
|
||||
backupPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
|
||||
container.Database,
|
||||
backupUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT USAGE ON SCHEMA public TO "%s"`,
|
||||
backupUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT SELECT ON ALL TABLES IN SCHEMA public TO "%s"`,
|
||||
backupUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, backupUsername))
|
||||
}()
|
||||
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Version: tools.GetPostgresqlVersionEnum(tc.version),
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: backupUsername,
|
||||
Password: backupPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = pgModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
@@ -44,13 +269,60 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
isReadOnly, err := pgModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
isReadOnly, privileges, err := pgModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, isReadOnly, "Admin user should not be read-only")
|
||||
assert.NotEmpty(t, privileges, "Admin user should have privileges")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
DROP TABLE IF EXISTS readonly_check_test CASCADE;
|
||||
CREATE TABLE readonly_check_test (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO readonly_check_test (data) VALUES ('test1');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyModel := &PostgresqlDatabase{
|
||||
Version: pgModel.Version,
|
||||
Host: pgModel.Host,
|
||||
Port: pgModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: pgModel.Database,
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Read-only user should be read-only")
|
||||
assert.Empty(t, privileges, "Read-only user should have no write privileges")
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
@@ -105,9 +377,15 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
|
||||
ctx,
|
||||
logger,
|
||||
nil,
|
||||
uuid.New(),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Created user should be read-only")
|
||||
assert.Empty(t, privileges, "Read-only user should have no write privileges")
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
@@ -142,7 +420,6 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
// Clean up: Drop user with CASCADE to handle default privilege dependencies
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
@@ -186,7 +463,6 @@ func Test_ReadOnlyUser_FutureTables_HaveSelectPermission(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "future_data", data)
|
||||
|
||||
// Clean up: Drop user with CASCADE to handle default privilege dependencies
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
@@ -202,8 +478,10 @@ func Test_ReadOnlyUser_MultipleSchemas_AllAccessible(t *testing.T) {
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
CREATE SCHEMA IF NOT EXISTS schema_a;
|
||||
CREATE SCHEMA IF NOT EXISTS schema_b;
|
||||
DROP SCHEMA IF EXISTS schema_a CASCADE;
|
||||
DROP SCHEMA IF EXISTS schema_b CASCADE;
|
||||
CREATE SCHEMA schema_a;
|
||||
CREATE SCHEMA schema_b;
|
||||
CREATE TABLE schema_a.table_a (id INT, data TEXT);
|
||||
CREATE TABLE schema_b.table_b (id INT, data TEXT);
|
||||
INSERT INTO schema_a.table_a VALUES (1, 'data_a');
|
||||
@@ -234,7 +512,6 @@ func Test_ReadOnlyUser_MultipleSchemas_AllAccessible(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "data_b", dataB)
|
||||
|
||||
// Clean up: Drop user with CASCADE to handle default privilege dependencies
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
@@ -341,7 +618,7 @@ func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
|
||||
)
|
||||
|
||||
adminDB, err := sqlx.Connect("postgres", dsn)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
defer adminDB.Close()
|
||||
|
||||
tableName := fmt.Sprintf(
|
||||
@@ -428,6 +705,233 @@ func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenLocalhostAndDatabasus_ReturnsError(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
username string
|
||||
database string
|
||||
}{
|
||||
{
|
||||
name: "localhost with databasus db",
|
||||
host: "localhost",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 with databasus db",
|
||||
host: "127.0.0.1",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "172.17.0.1 with databasus db",
|
||||
host: "172.17.0.1",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "host.docker.internal with databasus db",
|
||||
host: "host.docker.internal",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "LOCALHOST (uppercase) with DATABASUS db",
|
||||
host: "LOCALHOST",
|
||||
username: "POSTGRES",
|
||||
database: "DATABASUS",
|
||||
},
|
||||
{
|
||||
name: "LocalHost (mixed case) with DataBasus db",
|
||||
host: "LocalHost",
|
||||
username: "anyuser",
|
||||
database: "DataBasus",
|
||||
},
|
||||
{
|
||||
name: "localhost with databasus and any username",
|
||||
host: "localhost",
|
||||
username: "myuser",
|
||||
database: "databasus",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Host: tc.host,
|
||||
Port: 5437,
|
||||
Username: tc.username,
|
||||
Password: "somepassword",
|
||||
Database: &tc.database,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
err := pgModel.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "backing up Databasus internal database is not allowed")
|
||||
assert.Contains(t, err.Error(), "https://databasus.com/faq#backup-databasus")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Validate_WhenNotLocalhostOrNotDatabasus_ValidatesSuccessfully(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
username string
|
||||
database string
|
||||
}{
|
||||
{
|
||||
name: "different host (remote server) with databasus db",
|
||||
host: "192.168.1.100",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "different database name on localhost",
|
||||
host: "localhost",
|
||||
username: "postgres",
|
||||
database: "myapp",
|
||||
},
|
||||
{
|
||||
name: "all different",
|
||||
host: "db.example.com",
|
||||
username: "appuser",
|
||||
database: "production",
|
||||
},
|
||||
{
|
||||
name: "localhost with postgres database",
|
||||
host: "localhost",
|
||||
username: "postgres",
|
||||
database: "postgres",
|
||||
},
|
||||
{
|
||||
name: "remote host with databasus db name (allowed)",
|
||||
host: "db.example.com",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Host: tc.host,
|
||||
Port: 5432,
|
||||
Username: tc.username,
|
||||
Password: "somepassword",
|
||||
Database: &tc.database,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
err := pgModel.Validate()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Validate_WhenDatabaseIsNil_ValidatesSuccessfully(t *testing.T) {
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Host: "localhost",
|
||||
Port: 5437,
|
||||
Username: "postgres",
|
||||
Password: "somepassword",
|
||||
Database: nil,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
err := pgModel.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenDatabaseIsEmpty_ValidatesSuccessfully(t *testing.T) {
|
||||
emptyDb := ""
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Host: "localhost",
|
||||
Port: 5437,
|
||||
Username: "postgres",
|
||||
Password: "somepassword",
|
||||
Database: &emptyDb,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
err := pgModel.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRequiredFieldsMissing_ReturnsError(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
model *PostgresqlDatabase
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "missing host",
|
||||
model: &PostgresqlDatabase{
|
||||
Host: "",
|
||||
Port: 5432,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CpuCount: 1,
|
||||
},
|
||||
expectedError: "host is required",
|
||||
},
|
||||
{
|
||||
name: "missing port",
|
||||
model: &PostgresqlDatabase{
|
||||
Host: "localhost",
|
||||
Port: 0,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CpuCount: 1,
|
||||
},
|
||||
expectedError: "port is required",
|
||||
},
|
||||
{
|
||||
name: "missing username",
|
||||
model: &PostgresqlDatabase{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "",
|
||||
Password: "pass",
|
||||
CpuCount: 1,
|
||||
},
|
||||
expectedError: "username is required",
|
||||
},
|
||||
{
|
||||
name: "missing password",
|
||||
model: &PostgresqlDatabase{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "user",
|
||||
Password: "",
|
||||
CpuCount: 1,
|
||||
},
|
||||
expectedError: "password is required",
|
||||
},
|
||||
{
|
||||
name: "invalid cpu count",
|
||||
model: &PostgresqlDatabase{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CpuCount: 0,
|
||||
},
|
||||
expectedError: "cpu count must be greater than 0",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.model.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.expectedError)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type PostgresContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
@@ -483,6 +987,7 @@ func createPostgresModel(container *PostgresContainer) *PostgresqlDatabase {
|
||||
Password: container.Password,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,4 +39,5 @@ func GetDatabaseController() *DatabaseController {
|
||||
|
||||
func SetupDependencies() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
|
||||
notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
|
||||
}
|
||||
|
||||
@@ -6,5 +6,6 @@ type CreateReadOnlyUserResponse struct {
|
||||
}
|
||||
|
||||
type IsReadOnlyResponse struct {
|
||||
IsReadOnly bool `json:"isReadOnly"`
|
||||
IsReadOnly bool `json:"isReadOnly"`
|
||||
Privileges []string `json:"privileges"`
|
||||
}
|
||||
|
||||
@@ -104,21 +104,21 @@ func (d *Database) EncryptSensitiveFields(encryptor encryption.FieldEncryptor) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) PopulateVersionIfEmpty(
|
||||
func (d *Database) PopulateDbData(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
) error {
|
||||
if d.Postgresql != nil {
|
||||
return d.Postgresql.PopulateVersionIfEmpty(logger, encryptor, d.ID)
|
||||
return d.Postgresql.PopulateDbData(logger, encryptor, d.ID)
|
||||
}
|
||||
if d.Mysql != nil {
|
||||
return d.Mysql.PopulateVersionIfEmpty(logger, encryptor, d.ID)
|
||||
return d.Mysql.PopulateDbData(logger, encryptor, d.ID)
|
||||
}
|
||||
if d.Mariadb != nil {
|
||||
return d.Mariadb.PopulateVersionIfEmpty(logger, encryptor, d.ID)
|
||||
return d.Mariadb.PopulateDbData(logger, encryptor, d.ID)
|
||||
}
|
||||
if d.Mongodb != nil {
|
||||
return d.Mongodb.PopulateVersionIfEmpty(logger, encryptor, d.ID)
|
||||
return d.Mongodb.PopulateDbData(logger, encryptor, d.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -243,3 +243,19 @@ func (r *DatabaseRepository) GetAllDatabases() ([]*Database, error) {
|
||||
|
||||
return databases, nil
|
||||
}
|
||||
|
||||
func (r *DatabaseRepository) GetDatabasesIDsByNotifierID(
|
||||
notifierID uuid.UUID,
|
||||
) ([]uuid.UUID, error) {
|
||||
var databasesIDs []uuid.UUID
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Table("database_notifiers").
|
||||
Where("notifier_id = ?", notifierID).
|
||||
Pluck("database_id", &databasesIDs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return databasesIDs, nil
|
||||
}
|
||||
|
||||
@@ -52,6 +52,17 @@ func (s *DatabaseService) AddDbCopyListener(
|
||||
s.dbCopyListener = append(s.dbCopyListener, dbCopyListener)
|
||||
}
|
||||
|
||||
func (s *DatabaseService) GetNotifierAttachedDatabasesIDs(
|
||||
notifierID uuid.UUID,
|
||||
) ([]uuid.UUID, error) {
|
||||
databasesIDs, err := s.dbRepository.GetDatabasesIDsByNotifierID(notifierID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return databasesIDs, nil
|
||||
}
|
||||
|
||||
func (s *DatabaseService) CreateDatabase(
|
||||
user *users_models.User,
|
||||
workspaceID uuid.UUID,
|
||||
@@ -71,8 +82,8 @@ func (s *DatabaseService) CreateDatabase(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := database.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
|
||||
return nil, fmt.Errorf("failed to auto-detect database version: %w", err)
|
||||
if err := database.PopulateDbData(s.logger, s.fieldEncryptor); err != nil {
|
||||
return nil, fmt.Errorf("failed to auto-detect database data: %w", err)
|
||||
}
|
||||
|
||||
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
|
||||
@@ -126,14 +137,20 @@ func (s *DatabaseService) UpdateDatabase(
|
||||
return err
|
||||
}
|
||||
|
||||
for _, notifier := range database.Notifiers {
|
||||
if notifier.WorkspaceID != *existingDatabase.WorkspaceID {
|
||||
return errors.New("notifier does not belong to this workspace")
|
||||
}
|
||||
}
|
||||
|
||||
existingDatabase.Update(database)
|
||||
|
||||
if err := existingDatabase.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := existingDatabase.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
|
||||
return fmt.Errorf("failed to auto-detect database version: %w", err)
|
||||
if err := existingDatabase.PopulateDbData(s.logger, s.fieldEncryptor); err != nil {
|
||||
return fmt.Errorf("failed to auto-detect database data: %w", err)
|
||||
}
|
||||
|
||||
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
|
||||
@@ -251,6 +268,23 @@ func (s *DatabaseService) IsNotifierUsing(
|
||||
return s.dbRepository.IsNotifierUsing(notifierID)
|
||||
}
|
||||
|
||||
func (s *DatabaseService) CountDatabasesByNotifier(
|
||||
user *users_models.User,
|
||||
notifierID uuid.UUID,
|
||||
) (int, error) {
|
||||
_, err := s.notifierService.GetNotifier(user, notifierID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
databaseIDs, err := s.dbRepository.GetDatabasesIDsByNotifierID(notifierID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(databaseIDs), nil
|
||||
}
|
||||
|
||||
func (s *DatabaseService) TestDatabaseConnection(
|
||||
user *users_models.User,
|
||||
databaseID uuid.UUID,
|
||||
@@ -481,6 +515,48 @@ func (s *DatabaseService) CopyDatabase(
|
||||
return copiedDatabase, nil
|
||||
}
|
||||
|
||||
func (s *DatabaseService) TransferDatabaseToWorkspace(
|
||||
databaseID uuid.UUID,
|
||||
targetWorkspaceID uuid.UUID,
|
||||
) error {
|
||||
database, err := s.dbRepository.FindByID(databaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sourceWorkspaceID := database.WorkspaceID
|
||||
database.WorkspaceID = &targetWorkspaceID
|
||||
|
||||
_, err = s.dbRepository.Save(database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Database transferred: %s from workspace %s to workspace %s",
|
||||
database.Name, sourceWorkspaceID, targetWorkspaceID),
|
||||
nil,
|
||||
&targetWorkspaceID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DatabaseService) UpdateDatabaseNotifiers(
|
||||
databaseID uuid.UUID,
|
||||
newNotifiers []notifiers.Notifier,
|
||||
) error {
|
||||
database, err := s.dbRepository.FindByID(databaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
database.Notifiers = newNotifiers
|
||||
|
||||
_, err = s.dbRepository.Save(database)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *DatabaseService) SetHealthStatus(
|
||||
databaseID uuid.UUID,
|
||||
healthStatus *HealthStatus,
|
||||
@@ -518,17 +594,17 @@ func (s *DatabaseService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error
|
||||
func (s *DatabaseService) IsUserReadOnly(
|
||||
user *users_models.User,
|
||||
database *Database,
|
||||
) (bool, error) {
|
||||
) (bool, []string, error) {
|
||||
var usingDatabase *Database
|
||||
|
||||
if database.ID != uuid.Nil {
|
||||
existingDatabase, err := s.dbRepository.FindByID(database.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
if existingDatabase.WorkspaceID == nil {
|
||||
return false, errors.New("cannot check user for database without workspace")
|
||||
return false, nil, errors.New("cannot check user for database without workspace")
|
||||
}
|
||||
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
|
||||
@@ -536,31 +612,34 @@ func (s *DatabaseService) IsUserReadOnly(
|
||||
user,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
if !canAccess {
|
||||
return false, errors.New("insufficient permissions to access this database")
|
||||
return false, nil, errors.New("insufficient permissions to access this database")
|
||||
}
|
||||
|
||||
if database.WorkspaceID != nil && *existingDatabase.WorkspaceID != *database.WorkspaceID {
|
||||
return false, errors.New("database does not belong to this workspace")
|
||||
return false, nil, errors.New("database does not belong to this workspace")
|
||||
}
|
||||
|
||||
existingDatabase.Update(database)
|
||||
|
||||
if err := existingDatabase.Validate(); err != nil {
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
usingDatabase = existingDatabase
|
||||
} else {
|
||||
if database.WorkspaceID != nil {
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
|
||||
*database.WorkspaceID,
|
||||
user,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
if !canAccess {
|
||||
return false, errors.New("insufficient permissions to access this workspace")
|
||||
return false, nil, errors.New("insufficient permissions to access this workspace")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -600,7 +679,7 @@ func (s *DatabaseService) IsUserReadOnly(
|
||||
usingDatabase.ID,
|
||||
)
|
||||
default:
|
||||
return false, errors.New("read-only check not supported for this database type")
|
||||
return false, nil, errors.New("read-only check not supported for this database type")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
package databases
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/databases/databases/mariadb"
|
||||
"databasus-backend/internal/features/databases/databases/mongodb"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
@@ -9,6 +15,71 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func GetTestPostgresConfig() *postgresql.PostgresqlDatabase {
|
||||
env := config.GetEnv()
|
||||
port, err := strconv.Atoi(env.TestPostgres16Port)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
return &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func GetTestMariadbConfig() *mariadb.MariadbDatabase {
|
||||
env := config.GetEnv()
|
||||
portStr := env.TestMariadb1011Port
|
||||
if portStr == "" {
|
||||
portStr = "33111"
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
return &mariadb.MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
}
|
||||
}
|
||||
|
||||
func GetTestMongodbConfig() *mongodb.MongodbDatabase {
|
||||
env := config.GetEnv()
|
||||
portStr := env.TestMongodb70Port
|
||||
if portStr == "" {
|
||||
portStr = "27070"
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err))
|
||||
}
|
||||
|
||||
return &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion7,
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
Username: "root",
|
||||
Password: "rootpassword",
|
||||
Database: "testdb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestDatabase(
|
||||
workspaceID uuid.UUID,
|
||||
storage *storages.Storage,
|
||||
@@ -18,16 +89,7 @@ func CreateTestDatabase(
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: "test " + uuid.New().String(),
|
||||
Type: DatabaseTypePostgres,
|
||||
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
CpuCount: 1,
|
||||
},
|
||||
|
||||
Postgresql: GetTestPostgresConfig(),
|
||||
Notifiers: []notifiers.Notifier{
|
||||
*notifier,
|
||||
},
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package healthcheck_attempt
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
"context"
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
"log/slog"
|
||||
"time"
|
||||
@@ -13,18 +13,19 @@ type HealthcheckAttemptBackgroundService struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *HealthcheckAttemptBackgroundService) Run() {
|
||||
func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) {
|
||||
// first healthcheck immediately
|
||||
s.checkDatabases()
|
||||
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
if config.IsShouldShutdown() {
|
||||
break
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.checkDatabases()
|
||||
}
|
||||
|
||||
s.checkDatabases()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,13 +12,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
@@ -111,7 +109,13 @@ func Test_GetAttemptsByDatabase_PermissionsEnforced(t *testing.T) {
|
||||
testUserToken = owner.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
member,
|
||||
*tt.workspaceRole,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
testUserToken = member.Token
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
@@ -205,20 +209,11 @@ func createTestDatabaseViaAPI(
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
request := databases.Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: name,
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Postgresql: databases.GetTestPostgresConfig(),
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
|
||||
@@ -10,13 +10,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
@@ -90,7 +88,13 @@ func Test_SaveHealthcheckConfig_PermissionsEnforced(t *testing.T) {
|
||||
testUserToken = owner.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
member,
|
||||
*tt.workspaceRole,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
|
||||
@@ -228,7 +232,13 @@ func Test_GetHealthcheckConfig_PermissionsEnforced(t *testing.T) {
|
||||
testUserToken = owner.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
member,
|
||||
*tt.workspaceRole,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
testUserToken = member.Token
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
@@ -293,20 +303,11 @@ func createTestDatabaseViaAPI(
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
request := databases.Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: name,
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Postgresql: databases.GetTestPostgresConfig(),
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package notifiers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"net/http"
|
||||
@@ -20,6 +22,7 @@ func (c *NotifierController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
router.GET("/notifiers/:id", c.GetNotifier)
|
||||
router.DELETE("/notifiers/:id", c.DeleteNotifier)
|
||||
router.POST("/notifiers/:id/test", c.SendTestNotification)
|
||||
router.POST("/notifiers/:id/transfer", c.TransferNotifierToWorkspace)
|
||||
router.POST("/notifiers/direct-test", c.SendTestNotificationDirect)
|
||||
}
|
||||
|
||||
@@ -55,7 +58,7 @@ func (c *NotifierController) SaveNotifier(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
if err := c.notifierService.SaveNotifier(user, request.WorkspaceID, &request); err != nil {
|
||||
if err.Error() == "insufficient permissions to manage notifier in this workspace" {
|
||||
if errors.Is(err, ErrInsufficientPermissionsToManageNotifier) {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -93,7 +96,7 @@ func (c *NotifierController) GetNotifier(ctx *gin.Context) {
|
||||
|
||||
notifier, err := c.notifierService.GetNotifier(user, id)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view notifier in this workspace" {
|
||||
if errors.Is(err, ErrInsufficientPermissionsToViewNotifier) {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -137,7 +140,7 @@ func (c *NotifierController) GetNotifiers(ctx *gin.Context) {
|
||||
|
||||
notifiers, err := c.notifierService.GetNotifiers(user, workspaceID)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view notifiers in this workspace" {
|
||||
if errors.Is(err, ErrInsufficientPermissionsToViewNotifiers) {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -174,7 +177,7 @@ func (c *NotifierController) DeleteNotifier(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
if err := c.notifierService.DeleteNotifier(user, id); err != nil {
|
||||
if err.Error() == "insufficient permissions to manage notifier in this workspace" {
|
||||
if errors.Is(err, ErrInsufficientPermissionsToManageNotifier) {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -211,7 +214,7 @@ func (c *NotifierController) SendTestNotification(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
if err := c.notifierService.SendTestNotification(user, id); err != nil {
|
||||
if err.Error() == "insufficient permissions to test notifier in this workspace" {
|
||||
if errors.Is(err, ErrInsufficientPermissionsToTestNotifier) {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -222,6 +225,62 @@ func (c *NotifierController) SendTestNotification(ctx *gin.Context) {
|
||||
ctx.JSON(http.StatusOK, gin.H{"message": "test notification sent successfully"})
|
||||
}
|
||||
|
||||
// TransferNotifierToWorkspace
|
||||
// @Summary Transfer notifier to another workspace
|
||||
// @Description Transfer a notifier from one workspace to another
|
||||
// @Tags notifiers
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "JWT token"
|
||||
// @Param id path string true "Notifier ID"
|
||||
// @Param request body TransferNotifierRequest true "Target workspace ID"
|
||||
// @Success 200
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /notifiers/{id}/transfer [post]
|
||||
func (c *NotifierController) TransferNotifierToWorkspace(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid notifier ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var request TransferNotifierRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if request.TargetWorkspaceID == uuid.Nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "targetWorkspaceId is required"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.notifierService.TransferNotifierToWorkspace(
|
||||
user,
|
||||
id,
|
||||
request.TargetWorkspaceID,
|
||||
nil,
|
||||
); err != nil {
|
||||
if errors.Is(err, ErrInsufficientPermissionsInSourceWorkspace) ||
|
||||
errors.Is(err, ErrInsufficientPermissionsInTargetWorkspace) {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, gin.H{"message": "notifier transferred successfully"})
|
||||
}
|
||||
|
||||
// SendTestNotificationDirect
|
||||
// @Summary Send test notification directly
|
||||
// @Description Send a test notification using a notifier object provided in the request
|
||||
|
||||
@@ -202,164 +202,161 @@ func Test_SendTestNotificationExisting_NotificationSent(t *testing.T) {
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_ViewerCanViewNotifiers_ButCannotModify(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
viewer := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
viewer,
|
||||
users_enums.WorkspaceRoleViewer,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
func Test_WorkspaceRolePermissions_Notifiers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
canCreate bool
|
||||
canUpdate bool
|
||||
canDelete bool
|
||||
}{
|
||||
{
|
||||
name: "owner can manage notifiers",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
canCreate: true,
|
||||
canUpdate: true,
|
||||
canDelete: true,
|
||||
},
|
||||
{
|
||||
name: "admin can manage notifiers",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
canCreate: true,
|
||||
canUpdate: true,
|
||||
canDelete: true,
|
||||
},
|
||||
{
|
||||
name: "member can manage notifiers",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
canCreate: true,
|
||||
canUpdate: true,
|
||||
canDelete: true,
|
||||
},
|
||||
{
|
||||
name: "viewer can view but cannot modify notifiers",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
canCreate: false,
|
||||
canUpdate: false,
|
||||
canDelete: false,
|
||||
},
|
||||
{
|
||||
name: "global admin can manage notifiers",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
canCreate: true,
|
||||
canUpdate: true,
|
||||
canDelete: true,
|
||||
},
|
||||
}
|
||||
|
||||
notifier := createNewNotifier(workspace.ID)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createRouter()
|
||||
GetNotifierService().SetNotifierDatabaseCounter(&mockNotifierDatabaseCounter{})
|
||||
|
||||
var savedNotifier Notifier
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/notifiers",
|
||||
"Bearer "+owner.Token,
|
||||
*notifier,
|
||||
http.StatusOK,
|
||||
&savedNotifier,
|
||||
)
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
// Viewer can GET notifiers
|
||||
var notifiers []Notifier
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/notifiers?workspace_id=%s", workspace.ID.String()),
|
||||
"Bearer "+viewer.Token,
|
||||
http.StatusOK,
|
||||
¬ifiers,
|
||||
)
|
||||
assert.Len(t, notifiers, 1)
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
testUser := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
testUser,
|
||||
*tt.workspaceRole,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
testUserToken = testUser.Token
|
||||
}
|
||||
|
||||
// Viewer cannot CREATE notifier
|
||||
newNotifier := createNewNotifier(workspace.ID)
|
||||
test_utils.MakePostRequest(
|
||||
t, router, "/api/v1/notifiers", "Bearer "+viewer.Token, *newNotifier, http.StatusForbidden,
|
||||
)
|
||||
// Owner creates initial notifier for all test cases
|
||||
var ownerNotifier Notifier
|
||||
notifier := createNewNotifier(workspace.ID)
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t, router, "/api/v1/notifiers", "Bearer "+owner.Token,
|
||||
*notifier, http.StatusOK, &ownerNotifier,
|
||||
)
|
||||
|
||||
// Viewer cannot UPDATE notifier
|
||||
savedNotifier.Name = "Updated by viewer"
|
||||
test_utils.MakePostRequest(
|
||||
t, router, "/api/v1/notifiers", "Bearer "+viewer.Token, savedNotifier, http.StatusForbidden,
|
||||
)
|
||||
// Test GET notifiers
|
||||
var notifiers []Notifier
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t, router,
|
||||
fmt.Sprintf("/api/v1/notifiers?workspace_id=%s", workspace.ID.String()),
|
||||
"Bearer "+testUserToken, http.StatusOK, ¬ifiers,
|
||||
)
|
||||
assert.Len(t, notifiers, 1)
|
||||
|
||||
// Viewer cannot DELETE notifier
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/notifiers/%s", savedNotifier.ID.String()),
|
||||
"Bearer "+viewer.Token,
|
||||
http.StatusForbidden,
|
||||
)
|
||||
// Test CREATE notifier
|
||||
createStatusCode := http.StatusOK
|
||||
if !tt.canCreate {
|
||||
createStatusCode = http.StatusForbidden
|
||||
}
|
||||
newNotifier := createNewNotifier(workspace.ID)
|
||||
var savedNotifier Notifier
|
||||
if tt.canCreate {
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t, router, "/api/v1/notifiers", "Bearer "+testUserToken,
|
||||
*newNotifier, createStatusCode, &savedNotifier,
|
||||
)
|
||||
assert.NotEmpty(t, savedNotifier.ID)
|
||||
} else {
|
||||
test_utils.MakePostRequest(
|
||||
t, router, "/api/v1/notifiers", "Bearer "+testUserToken,
|
||||
*newNotifier, createStatusCode,
|
||||
)
|
||||
}
|
||||
|
||||
deleteNotifier(t, router, savedNotifier.ID, workspace.ID, owner.Token)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
// Test UPDATE notifier
|
||||
updateStatusCode := http.StatusOK
|
||||
if !tt.canUpdate {
|
||||
updateStatusCode = http.StatusForbidden
|
||||
}
|
||||
ownerNotifier.Name = "Updated by test user"
|
||||
if tt.canUpdate {
|
||||
var updatedNotifier Notifier
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t, router, "/api/v1/notifiers", "Bearer "+testUserToken,
|
||||
ownerNotifier, updateStatusCode, &updatedNotifier,
|
||||
)
|
||||
assert.Equal(t, "Updated by test user", updatedNotifier.Name)
|
||||
} else {
|
||||
test_utils.MakePostRequest(
|
||||
t, router, "/api/v1/notifiers", "Bearer "+testUserToken,
|
||||
ownerNotifier, updateStatusCode,
|
||||
)
|
||||
}
|
||||
|
||||
func Test_MemberCanManageNotifiers(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
member,
|
||||
users_enums.WorkspaceRoleMember,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
// Test DELETE notifier
|
||||
deleteStatusCode := http.StatusOK
|
||||
if !tt.canDelete {
|
||||
deleteStatusCode = http.StatusForbidden
|
||||
}
|
||||
test_utils.MakeDeleteRequest(
|
||||
t, router,
|
||||
fmt.Sprintf("/api/v1/notifiers/%s", ownerNotifier.ID.String()),
|
||||
"Bearer "+testUserToken, deleteStatusCode,
|
||||
)
|
||||
|
||||
notifier := createNewNotifier(workspace.ID)
|
||||
|
||||
// Member can CREATE notifier
|
||||
var savedNotifier Notifier
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/notifiers",
|
||||
"Bearer "+member.Token,
|
||||
*notifier,
|
||||
http.StatusOK,
|
||||
&savedNotifier,
|
||||
)
|
||||
assert.NotEmpty(t, savedNotifier.ID)
|
||||
|
||||
// Member can UPDATE notifier
|
||||
savedNotifier.Name = "Updated by member"
|
||||
var updatedNotifier Notifier
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/notifiers",
|
||||
"Bearer "+member.Token,
|
||||
savedNotifier,
|
||||
http.StatusOK,
|
||||
&updatedNotifier,
|
||||
)
|
||||
assert.Equal(t, "Updated by member", updatedNotifier.Name)
|
||||
|
||||
// Member can DELETE notifier
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/notifiers/%s", savedNotifier.ID.String()),
|
||||
"Bearer "+member.Token,
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_AdminCanManageNotifiers(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
admin,
|
||||
users_enums.WorkspaceRoleAdmin,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
|
||||
notifier := createNewNotifier(workspace.ID)
|
||||
|
||||
// Admin can CREATE, UPDATE, DELETE
|
||||
var savedNotifier Notifier
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/notifiers",
|
||||
"Bearer "+admin.Token,
|
||||
*notifier,
|
||||
http.StatusOK,
|
||||
&savedNotifier,
|
||||
)
|
||||
|
||||
savedNotifier.Name = "Updated by admin"
|
||||
test_utils.MakePostRequest(
|
||||
t, router, "/api/v1/notifiers", "Bearer "+admin.Token, savedNotifier, http.StatusOK,
|
||||
)
|
||||
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/notifiers/%s", savedNotifier.ID.String()),
|
||||
"Bearer "+admin.Token,
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
// Cleanup
|
||||
if tt.canCreate {
|
||||
deleteNotifier(t, router, savedNotifier.ID, workspace.ID, owner.Token)
|
||||
}
|
||||
if !tt.canDelete {
|
||||
deleteNotifier(t, router, ownerNotifier.ID, workspace.ID, owner.Token)
|
||||
}
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_UserNotInWorkspace_CannotAccessNotifiers(t *testing.T) {
|
||||
@@ -678,6 +675,10 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
WebhookNotifier: &webhook_notifier.WebhookNotifier{
|
||||
WebhookURL: "https://webhook.example.com/test",
|
||||
WebhookMethod: webhook_notifier.WebhookMethodPOST,
|
||||
Headers: []webhook_notifier.WebhookHeader{
|
||||
{Key: "Authorization", Value: "Bearer my-secret-token"},
|
||||
{Key: "X-Custom-Header", Value: "custom-value"},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
@@ -690,14 +691,40 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
WebhookNotifier: &webhook_notifier.WebhookNotifier{
|
||||
WebhookURL: "https://webhook.example.com/updated",
|
||||
WebhookMethod: webhook_notifier.WebhookMethodGET,
|
||||
Headers: []webhook_notifier.WebhookHeader{
|
||||
{Key: "Authorization", Value: "Bearer updated-token"},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
|
||||
// No sensitive data to verify for webhook
|
||||
assert.NotEmpty(
|
||||
t,
|
||||
notifier.WebhookNotifier.WebhookURL,
|
||||
"WebhookURL should be visible",
|
||||
)
|
||||
// Verify header values are encrypted in DB
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.WebhookNotifier.Headers[0].Value),
|
||||
"Header value should be encrypted in DB",
|
||||
)
|
||||
decrypted := decryptField(
|
||||
t,
|
||||
notifier.ID,
|
||||
notifier.WebhookNotifier.Headers[0].Value,
|
||||
)
|
||||
assert.Equal(t, "Bearer updated-token", decrypted)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
|
||||
// No sensitive data to hide for webhook
|
||||
assert.NotEmpty(
|
||||
t,
|
||||
notifier.WebhookNotifier.WebhookURL,
|
||||
"WebhookURL should be visible",
|
||||
)
|
||||
for _, header := range notifier.WebhookNotifier.Headers {
|
||||
assert.Empty(t, header.Value, "Header value should be hidden")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -908,7 +935,7 @@ func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Webhook Notifier - WebhookURL encrypted",
|
||||
name: "Webhook Notifier - Header values encrypted, URL not encrypted",
|
||||
createNotifier: func(workspaceID uuid.UUID) *Notifier {
|
||||
return &Notifier{
|
||||
WorkspaceID: workspaceID,
|
||||
@@ -917,17 +944,48 @@ func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
|
||||
WebhookNotifier: &webhook_notifier.WebhookNotifier{
|
||||
WebhookURL: "https://webhook.example.com/test456",
|
||||
WebhookMethod: webhook_notifier.WebhookMethodPOST,
|
||||
Headers: []webhook_notifier.WebhookHeader{
|
||||
{Key: "Authorization", Value: "Bearer secret-token-12345"},
|
||||
{Key: "X-API-Key", Value: "api-key-67890"},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
|
||||
assert.True(
|
||||
assert.False(
|
||||
t,
|
||||
isEncrypted(notifier.WebhookNotifier.WebhookURL),
|
||||
"WebhookURL should be encrypted",
|
||||
"WebhookURL should NOT be encrypted",
|
||||
)
|
||||
decrypted := decryptField(t, notifier.ID, notifier.WebhookNotifier.WebhookURL)
|
||||
assert.Equal(t, "https://webhook.example.com/test456", decrypted)
|
||||
assert.Equal(
|
||||
t,
|
||||
"https://webhook.example.com/test456",
|
||||
notifier.WebhookNotifier.WebhookURL,
|
||||
)
|
||||
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.WebhookNotifier.Headers[0].Value),
|
||||
"Header value should be encrypted",
|
||||
)
|
||||
decrypted1 := decryptField(
|
||||
t,
|
||||
notifier.ID,
|
||||
notifier.WebhookNotifier.Headers[0].Value,
|
||||
)
|
||||
assert.Equal(t, "Bearer secret-token-12345", decrypted1)
|
||||
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.WebhookNotifier.Headers[1].Value),
|
||||
"Header value should be encrypted",
|
||||
)
|
||||
decrypted2 := decryptField(
|
||||
t,
|
||||
notifier.ID,
|
||||
notifier.WebhookNotifier.Headers[1].Value,
|
||||
)
|
||||
assert.Equal(t, "api-key-67890", decrypted2)
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -965,6 +1023,204 @@ func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TransferNotifier_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sourceRole *users_enums.WorkspaceRole
|
||||
targetRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "owner in both workspaces can transfer",
|
||||
sourceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
targetRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "admin in both workspaces can transfer",
|
||||
sourceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
|
||||
targetRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "member in both workspaces can transfer",
|
||||
sourceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
targetRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "viewer in both workspaces cannot transfer",
|
||||
sourceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
targetRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "global admin can transfer",
|
||||
sourceRole: nil,
|
||||
targetRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createRouter()
|
||||
GetNotifierService().SetNotifierDatabaseCounter(&mockNotifierDatabaseCounter{})
|
||||
|
||||
sourceOwner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
targetOwner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
sourceWorkspace := workspaces_testing.CreateTestWorkspace(
|
||||
"Source Workspace",
|
||||
sourceOwner,
|
||||
router,
|
||||
)
|
||||
targetWorkspace := workspaces_testing.CreateTestWorkspace(
|
||||
"Target Workspace",
|
||||
targetOwner,
|
||||
router,
|
||||
)
|
||||
|
||||
notifier := createNewNotifier(sourceWorkspace.ID)
|
||||
var savedNotifier Notifier
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/notifiers",
|
||||
"Bearer "+sourceOwner.Token,
|
||||
*notifier,
|
||||
http.StatusOK,
|
||||
&savedNotifier,
|
||||
)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.sourceRole != nil {
|
||||
testUser := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
sourceWorkspace,
|
||||
testUser,
|
||||
*tt.sourceRole,
|
||||
sourceOwner.Token,
|
||||
router,
|
||||
)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
targetWorkspace,
|
||||
testUser,
|
||||
*tt.targetRole,
|
||||
targetOwner.Token,
|
||||
router,
|
||||
)
|
||||
testUserToken = testUser.Token
|
||||
}
|
||||
|
||||
request := TransferNotifierRequest{
|
||||
TargetWorkspaceID: targetWorkspace.ID,
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/notifiers/%s/transfer", savedNotifier.ID.String()),
|
||||
"Bearer "+testUserToken,
|
||||
request,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
|
||||
if tt.expectSuccess {
|
||||
assert.Contains(t, string(testResp.Body), "transferred successfully")
|
||||
|
||||
var retrievedNotifier Notifier
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/notifiers/%s", savedNotifier.ID.String()),
|
||||
"Bearer "+targetOwner.Token,
|
||||
http.StatusOK,
|
||||
&retrievedNotifier,
|
||||
)
|
||||
assert.Equal(t, targetWorkspace.ID, retrievedNotifier.WorkspaceID)
|
||||
|
||||
deleteNotifier(t, router, savedNotifier.ID, targetWorkspace.ID, targetOwner.Token)
|
||||
} else {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
deleteNotifier(t, router, savedNotifier.ID, sourceWorkspace.ID, sourceOwner.Token)
|
||||
}
|
||||
|
||||
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
|
||||
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TransferNotifierNotManagableWorkspace_TransferFailed(t *testing.T) {
|
||||
router := createRouter()
|
||||
GetNotifierService().SetNotifierDatabaseCounter(&mockNotifierDatabaseCounter{})
|
||||
|
||||
userA := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
userB := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
workspace1 := workspaces_testing.CreateTestWorkspace("Workspace 1", userA, router)
|
||||
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", userB, router)
|
||||
|
||||
notifier := createNewNotifier(workspace1.ID)
|
||||
var savedNotifier Notifier
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/notifiers",
|
||||
"Bearer "+userA.Token,
|
||||
*notifier,
|
||||
http.StatusOK,
|
||||
&savedNotifier,
|
||||
)
|
||||
|
||||
request := TransferNotifierRequest{
|
||||
TargetWorkspaceID: workspace2.ID,
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/notifiers/%s/transfer", savedNotifier.ID.String()),
|
||||
"Bearer "+userA.Token,
|
||||
request,
|
||||
http.StatusForbidden,
|
||||
)
|
||||
|
||||
assert.Contains(
|
||||
t,
|
||||
string(testResp.Body),
|
||||
"insufficient permissions to manage notifier in target workspace",
|
||||
)
|
||||
|
||||
deleteNotifier(t, router, savedNotifier.ID, workspace1.ID, userA.Token)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace1, router)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace2, router)
|
||||
}
|
||||
|
||||
type mockNotifierDatabaseCounter struct{}
|
||||
|
||||
func (m *mockNotifierDatabaseCounter) GetNotifierAttachedDatabasesIDs(
|
||||
notifierID uuid.UUID,
|
||||
) ([]uuid.UUID, error) {
|
||||
return []uuid.UUID{}, nil
|
||||
}
|
||||
|
||||
func createRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
@@ -979,6 +1235,7 @@ func createRouter() *gin.Engine {
|
||||
}
|
||||
|
||||
audit_logs.SetupDependencies()
|
||||
GetNotifierService().SetNotifierDatabaseCounter(&mockNotifierDatabaseCounter{})
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ var notifierService = &NotifierService{
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
nil,
|
||||
}
|
||||
var notifierController = &NotifierController{
|
||||
notifierService,
|
||||
|
||||
7
backend/internal/features/notifiers/dto.go
Normal file
7
backend/internal/features/notifiers/dto.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package notifiers
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type TransferNotifierRequest struct {
|
||||
TargetWorkspaceID uuid.UUID `json:"targetWorkspaceId" binding:"required"`
|
||||
}
|
||||
36
backend/internal/features/notifiers/errors.go
Normal file
36
backend/internal/features/notifiers/errors.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package notifiers
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrInsufficientPermissionsToManageNotifier = errors.New(
|
||||
"insufficient permissions to manage notifier in this workspace",
|
||||
)
|
||||
ErrInsufficientPermissionsToViewNotifier = errors.New(
|
||||
"insufficient permissions to view notifier in this workspace",
|
||||
)
|
||||
ErrInsufficientPermissionsToViewNotifiers = errors.New(
|
||||
"insufficient permissions to view notifiers in this workspace",
|
||||
)
|
||||
ErrInsufficientPermissionsToTestNotifier = errors.New(
|
||||
"insufficient permissions to test notifier in this workspace",
|
||||
)
|
||||
ErrNotifierDoesNotBelongToWorkspace = errors.New(
|
||||
"notifier does not belong to this workspace",
|
||||
)
|
||||
ErrInsufficientPermissionsInSourceWorkspace = errors.New(
|
||||
"insufficient permissions to manage notifier in source workspace",
|
||||
)
|
||||
ErrInsufficientPermissionsInTargetWorkspace = errors.New(
|
||||
"insufficient permissions to manage notifier in target workspace",
|
||||
)
|
||||
ErrNotifierHasAttachedDatabases = errors.New(
|
||||
"notifier has attached databases and cannot be deleted",
|
||||
)
|
||||
ErrNotifierHasAttachedDatabasesCannotTransfer = errors.New(
|
||||
"notifier has attached databases and cannot be transferred",
|
||||
)
|
||||
ErrNotifierHasOtherAttachedDatabasesCannotTransfer = errors.New(
|
||||
"notifier has other attached databases and cannot be transferred",
|
||||
)
|
||||
)
|
||||
@@ -3,6 +3,8 @@ package notifiers
|
||||
import (
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"log/slog"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type NotificationSender interface {
|
||||
@@ -19,3 +21,7 @@ type NotificationSender interface {
|
||||
|
||||
EncryptSensitiveData(encryptor encryption.FieldEncryptor) error
|
||||
}
|
||||
|
||||
type NotifierDatabaseCounter interface {
|
||||
GetNotifierAttachedDatabasesIDs(notifierID uuid.UUID) ([]uuid.UUID, error)
|
||||
}
|
||||
|
||||
@@ -21,6 +21,10 @@ type WebhookHeader struct {
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
// Before both WebhookURL, BodyTemplate and HeadersJSON were considered
|
||||
// as sensetive data and it was causing issues. Now only headers values
|
||||
// considered as sensetive data, but we try to decrypt webhook URL and
|
||||
// body template for backward combability
|
||||
type WebhookNotifier struct {
|
||||
NotifierID uuid.UUID `json:"notifierId" gorm:"primaryKey;column:notifier_id"`
|
||||
WebhookURL string `json:"webhookUrl" gorm:"not null;column:webhook_url"`
|
||||
@@ -58,6 +62,20 @@ func (t *WebhookNotifier) AfterFind(_ *gorm.DB) error {
|
||||
}
|
||||
}
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
|
||||
if t.WebhookURL != "" {
|
||||
if decrypted, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL); err == nil {
|
||||
t.WebhookURL = decrypted
|
||||
}
|
||||
}
|
||||
|
||||
if t.BodyTemplate != nil && *t.BodyTemplate != "" {
|
||||
if decrypted, err := encryptor.Decrypt(t.NotifierID, *t.BodyTemplate); err == nil {
|
||||
t.BodyTemplate = &decrypted
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -79,22 +97,24 @@ func (t *WebhookNotifier) Send(
|
||||
heading string,
|
||||
message string,
|
||||
) error {
|
||||
webhookURL, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
|
||||
if err := t.decryptHeadersForSending(encryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch t.WebhookMethod {
|
||||
case WebhookMethodGET:
|
||||
return t.sendGET(webhookURL, heading, message, logger)
|
||||
return t.sendGET(t.WebhookURL, heading, message, logger)
|
||||
case WebhookMethodPOST:
|
||||
return t.sendPOST(webhookURL, heading, message, logger)
|
||||
return t.sendPOST(t.WebhookURL, heading, message, logger)
|
||||
default:
|
||||
return fmt.Errorf("unsupported webhook method: %s", t.WebhookMethod)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) HideSensitiveData() {
|
||||
for i := range t.Headers {
|
||||
t.Headers[i].Value = ""
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
|
||||
@@ -105,14 +125,15 @@ func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if t.WebhookURL != "" {
|
||||
encrypted, err := encryptor.Encrypt(t.NotifierID, t.WebhookURL)
|
||||
for i := range t.Headers {
|
||||
if t.Headers[i].Value != "" {
|
||||
encrypted, err := encryptor.Encrypt(t.NotifierID, t.Headers[i].Value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt header value: %w", err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
|
||||
t.Headers[i].Value = encrypted
|
||||
}
|
||||
|
||||
t.WebhookURL = encrypted
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -241,3 +262,15 @@ func escapeJSONString(s string) string {
|
||||
|
||||
return string(b[1 : len(b)-1])
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) decryptHeadersForSending(encryptor encryption.FieldEncryptor) error {
|
||||
for i := range t.Headers {
|
||||
if t.Headers[i].Value != "" {
|
||||
if decrypted, err := encryptor.Decrypt(t.NotifierID, t.Headers[i].Value); err == nil {
|
||||
t.Headers[i].Value = decrypted
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package notifiers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
@@ -14,11 +13,18 @@ import (
|
||||
)
|
||||
|
||||
type NotifierService struct {
|
||||
notifierRepository *NotifierRepository
|
||||
logger *slog.Logger
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
fieldEncryptor encryption.FieldEncryptor
|
||||
notifierRepository *NotifierRepository
|
||||
logger *slog.Logger
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
fieldEncryptor encryption.FieldEncryptor
|
||||
notifierDatabaseCounter NotifierDatabaseCounter
|
||||
}
|
||||
|
||||
func (s *NotifierService) SetNotifierDatabaseCounter(
|
||||
notifierDatabaseCounter NotifierDatabaseCounter,
|
||||
) {
|
||||
s.notifierDatabaseCounter = notifierDatabaseCounter
|
||||
}
|
||||
|
||||
func (s *NotifierService) SaveNotifier(
|
||||
@@ -31,7 +37,7 @@ func (s *NotifierService) SaveNotifier(
|
||||
return err
|
||||
}
|
||||
if !canManage {
|
||||
return errors.New("insufficient permissions to manage notifier in this workspace")
|
||||
return ErrInsufficientPermissionsToManageNotifier
|
||||
}
|
||||
|
||||
isUpdate := notifier.ID != uuid.Nil
|
||||
@@ -43,7 +49,7 @@ func (s *NotifierService) SaveNotifier(
|
||||
}
|
||||
|
||||
if existingNotifier.WorkspaceID != workspaceID {
|
||||
return errors.New("notifier does not belong to this workspace")
|
||||
return ErrNotifierDoesNotBelongToWorkspace
|
||||
}
|
||||
|
||||
existingNotifier.Update(notifier)
|
||||
@@ -106,7 +112,17 @@ func (s *NotifierService) DeleteNotifier(
|
||||
return err
|
||||
}
|
||||
if !canManage {
|
||||
return errors.New("insufficient permissions to manage notifier in this workspace")
|
||||
return ErrInsufficientPermissionsToManageNotifier
|
||||
}
|
||||
|
||||
attachedDatabasesIDs, err := s.notifierDatabaseCounter.GetNotifierAttachedDatabasesIDs(
|
||||
notifier.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(attachedDatabasesIDs) > 0 {
|
||||
return ErrNotifierHasAttachedDatabases
|
||||
}
|
||||
|
||||
err = s.notifierRepository.Delete(notifier)
|
||||
@@ -137,13 +153,17 @@ func (s *NotifierService) GetNotifier(
|
||||
return nil, err
|
||||
}
|
||||
if !canView {
|
||||
return nil, errors.New("insufficient permissions to view notifier in this workspace")
|
||||
return nil, ErrInsufficientPermissionsToViewNotifier
|
||||
}
|
||||
|
||||
notifier.HideSensitiveData()
|
||||
return notifier, nil
|
||||
}
|
||||
|
||||
func (s *NotifierService) GetNotifierByID(id uuid.UUID) (*Notifier, error) {
|
||||
return s.notifierRepository.FindByID(id)
|
||||
}
|
||||
|
||||
func (s *NotifierService) GetNotifiers(
|
||||
user *users_models.User,
|
||||
workspaceID uuid.UUID,
|
||||
@@ -153,7 +173,7 @@ func (s *NotifierService) GetNotifiers(
|
||||
return nil, err
|
||||
}
|
||||
if !canView {
|
||||
return nil, errors.New("insufficient permissions to view notifiers in this workspace")
|
||||
return nil, ErrInsufficientPermissionsToViewNotifiers
|
||||
}
|
||||
|
||||
notifiers, err := s.notifierRepository.FindByWorkspaceID(workspaceID)
|
||||
@@ -182,7 +202,7 @@ func (s *NotifierService) SendTestNotification(
|
||||
return err
|
||||
}
|
||||
if !canView {
|
||||
return errors.New("insufficient permissions to test notifier in this workspace")
|
||||
return ErrInsufficientPermissionsToTestNotifier
|
||||
}
|
||||
|
||||
err = notifier.Send(s.fieldEncryptor, s.logger, "Test message", "This is a test message")
|
||||
@@ -210,7 +230,7 @@ func (s *NotifierService) SendTestNotificationToNotifier(
|
||||
}
|
||||
|
||||
if existingNotifier.WorkspaceID != notifier.WorkspaceID {
|
||||
return errors.New("notifier does not belong to this workspace")
|
||||
return ErrNotifierDoesNotBelongToWorkspace
|
||||
}
|
||||
|
||||
existingNotifier.Update(notifier)
|
||||
@@ -269,6 +289,70 @@ func (s *NotifierService) SendNotification(
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NotifierService) TransferNotifierToWorkspace(
|
||||
user *users_models.User,
|
||||
notifierID uuid.UUID,
|
||||
targetWorkspaceID uuid.UUID,
|
||||
transferingWithDbID *uuid.UUID,
|
||||
) error {
|
||||
existingNotifier, err := s.notifierRepository.FindByID(notifierID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
canManageSource, err := s.workspaceService.CanUserManageDBs(existingNotifier.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManageSource {
|
||||
return ErrInsufficientPermissionsInSourceWorkspace
|
||||
}
|
||||
|
||||
canManageTarget, err := s.workspaceService.CanUserManageDBs(targetWorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManageTarget {
|
||||
return ErrInsufficientPermissionsInTargetWorkspace
|
||||
}
|
||||
|
||||
attachedDatabasesIDs, err := s.notifierDatabaseCounter.GetNotifierAttachedDatabasesIDs(
|
||||
existingNotifier.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if transferingWithDbID != nil {
|
||||
for _, dbID := range attachedDatabasesIDs {
|
||||
if dbID != *transferingWithDbID {
|
||||
return ErrNotifierHasOtherAttachedDatabasesCannotTransfer
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if len(attachedDatabasesIDs) > 0 {
|
||||
return ErrNotifierHasAttachedDatabasesCannotTransfer
|
||||
}
|
||||
}
|
||||
|
||||
sourceWorkspaceID := existingNotifier.WorkspaceID
|
||||
existingNotifier.WorkspaceID = targetWorkspaceID
|
||||
|
||||
_, err = s.notifierRepository.Save(existingNotifier)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Notifier transferred: %s from workspace %s to workspace %s",
|
||||
existingNotifier.Name, sourceWorkspaceID, targetWorkspaceID),
|
||||
&user.ID,
|
||||
&targetWorkspaceID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NotifierService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
|
||||
notifiers, err := s.notifierRepository.FindByWorkspaceID(workspaceID)
|
||||
if err != nil {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user