Compare commits

...

36 Commits

Author SHA1 Message Date
Rostislav Dugin
4344f5ea5e Merge pull request #273 from databasus/develop
FIX (ci \ cd): Make DB files in CI \ CD executable
2026-01-15 22:17:06 +03:00
Rostislav Dugin
7c6afa5b88 FIX (ci \ cd): Make DB files in CI \ CD executable 2026-01-15 22:16:45 +03:00
Rostislav Dugin
dbac799e1b Merge pull request #272 from databasus/develop
FIX (backups): Add backups failure logging when it is expected
2026-01-15 22:02:39 +03:00
Rostislav Dugin
7ee3817089 FIX (backups): Add backups failure logging when it is expected 2026-01-15 22:01:53 +03:00
Rostislav Dugin
bae6f7f007 Merge pull request #271 from databasus/develop
Develop
2026-01-15 21:19:55 +03:00
Rostislav Dugin
55dc087ddd FIX (containers): Do not allow to backup internal DB from inside containers, instead give link to FAQ with manual how to backup Databasus in proper way 2026-01-15 21:18:37 +03:00
Rostislav Dugin
c94d0db637 FIX (ci \ cd): Remove caches and use assets from repo to avoid flucky tests over CI 2026-01-15 21:03:43 +03:00
Rostislav Dugin
a1adef2261 !REFACTOR (tasks): Move tasks cancellation and tracking to separate package from backuping to use for restores 2026-01-15 21:03:05 +03:00
Rostislav Dugin
4602dc3f88 Merge pull request #267 from databasus/develop
FIX (mysql): Enable allowCleartextPasswords over SSL
2026-01-14 18:13:46 +03:00
Rostislav Dugin
cbbfc5ea8f FIX (mysql): Enable allowCleartextPasswords over SSL 2026-01-14 18:11:49 +03:00
Rostislav Dugin
dd1072e230 Merge pull request #265 from databasus/develop
FIX (pre-commit): Add running go mod tidy in pre-commit
2026-01-14 15:18:35 +03:00
Rostislav Dugin
a495e5317a FIX (pre-commit): Add running go mod tidy in pre-commit 2026-01-14 15:18:06 +03:00
Rostislav Dugin
7eed647038 Merge pull request #264 from databasus/develop
Develop
2026-01-14 15:14:05 +03:00
Rostislav Dugin
6973241e25 FIX (backups): Throw error on parallel download token generation 2026-01-14 14:40:22 +03:00
Rostislav Dugin
ab181f5b81 FEATURE (bandwidth): Limit download throughput for backups to not exhaust more than 75% of server network bandwidth 2026-01-14 14:40:22 +03:00
Rostislav Dugin
b60a0cc170 FEATURE (backups): Allow single backup download to avoid exhausting of server throughput 2026-01-14 14:40:22 +03:00
Rostislav Dugin
f319a497b3 FEATURE (auth): Add rate limiting for sign in via email using sliding window 2026-01-14 14:40:22 +03:00
Rostislav Dugin
bc870b3f8e Merge pull request #261 from databasus/develop
FIX (webhook): Update webhook tests to not expect URL to be encrypted
2026-01-14 09:43:26 +03:00
Rostislav Dugin
15383c59eb FIX (webhook): Update webhook tests to not expect URL to be encrypted 2026-01-14 09:42:25 +03:00
Rostislav Dugin
d14c223a65 Merge pull request #259 from databasus/develop
Develop
2026-01-14 09:10:28 +03:00
Rostislav Dugin
2c0a294027 FIX (webhook): Do not encypt webhook URL, keep encyption for headers only 2026-01-14 09:09:00 +03:00
Rostislav Dugin
5d851d73bd FIX (mysql \ mariadb): Decrease strictness of SELECT check for health check 2026-01-14 08:39:27 +03:00
Rostislav Dugin
699913c251 FIX (postgresql): Filter TEMP table SELECT checks 2026-01-14 07:42:29 +03:00
Rostislav Dugin
a2e3f30a6d Merge pull request #258 from databasus/develop
FEATURE (backups): Add support of multinode Databasus setup
2026-01-14 07:34:06 +03:00
Rostislav Dugin
80f1174ecd FEATURE (backups): Add support of multinode Databasus setup 2026-01-14 07:32:13 +03:00
Rostislav Dugin
a47f8d5e2c Merge pull request #253 from databasus/develop
FIX (permissions check): Check permissions only in schemas selected f…
2026-01-12 14:23:24 +03:00
Rostislav Dugin
54b9e67656 FIX (permissions check): Check permissions only in schemas selected for backup 2026-01-12 14:22:12 +03:00
Rostislav Dugin
3782846872 Merge pull request #251 from databasus/develop
FIX (tidy): Run go mod tidy
2026-01-12 11:32:25 +03:00
Rostislav Dugin
245a81897f FIX (tidy): Run go mod tidy 2026-01-12 11:31:52 +03:00
Rostislav Dugin
5cbc0773b6 Merge pull request #250 from databasus/develop
FEATURE (backups): Move backups cancellation to Valkey pub\sub
2026-01-12 11:26:29 +03:00
Rostislav Dugin
997fc01442 FEATURE (backups): Move backups cancellation to Valkey pub\sub 2026-01-12 11:24:25 +03:00
Rostislav Dugin
6d0ae32d0c Merge pull request #240 from databasus/develop
FIX (oauth): Enable GitHub and Google OAuth
2026-01-10 20:15:43 +03:00
Rostislav Dugin
011985d723 FIX (oauth): Enable GitHub and Google OAuth 2026-01-10 19:19:37 +03:00
Rostislav Dugin
d677ee61de Merge pull request #239 from databasus/develop
FIX (mariadb): --skip-ssl-verify-server-cert for mariadb / mysql
2026-01-10 18:34:58 +03:00
Rostislav Dugin
c6b8f6e87a Merge pull request #237 from wzzrd/bugfix/disable_mariadb_mysql_ssl_verify
--skip-ssl-verify-server-cert for mariadb
2026-01-10 18:33:45 +03:00
Maxim Burgerhout
2bb5f93d00 --skip-ssl-verify-server-cert for mariadb / mysql
This change adds the --skip-ssl-verify-server-cert flag to mariadb
database connections for both backups and restores. This errors when
trying to verify certificates during those procedures.
2026-01-10 15:50:09 +01:00
112 changed files with 9073 additions and 3046 deletions

View File

@@ -17,17 +17,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.24.4"
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
go-version: "1.24.9"
- name: Install golangci-lint
run: |
@@ -63,8 +53,6 @@ jobs:
uses: actions/setup-node@v4
with:
node-version: "20"
cache: "npm"
cache-dependency-path: frontend/package-lock.json
- name: Install dependencies
run: |
@@ -93,8 +81,6 @@ jobs:
uses: actions/setup-node@v4
with:
node-version: "20"
cache: "npm"
cache-dependency-path: frontend/package-lock.json
- name: Install dependencies
run: |
@@ -134,17 +120,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.24.4"
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
go-version: "1.24.9"
- name: Create .env file for testing
run: |
@@ -221,6 +197,12 @@ jobs:
TEST_MONGODB_60_PORT=27060
TEST_MONGODB_70_PORT=27070
TEST_MONGODB_82_PORT=27082
# Valkey (cache)
VALKEY_HOST=localhost
VALKEY_PORT=6379
VALKEY_USERNAME=
VALKEY_PASSWORD=
VALKEY_IS_SSL=false
EOF
- name: Start test containers
@@ -233,6 +215,11 @@ jobs:
# Wait for main dev database
timeout 60 bash -c 'until docker exec dev-db pg_isready -h localhost -p 5437 -U postgres; do sleep 2; done'
# Wait for Valkey (cache)
echo "Waiting for Valkey..."
timeout 60 bash -c 'until docker exec dev-valkey valkey-cli ping 2>/dev/null | grep -q PONG; do sleep 2; done'
echo "Valkey is ready!"
# Wait for test databases
timeout 60 bash -c 'until nc -z localhost 5000; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5001; do sleep 2; done'
@@ -310,34 +297,6 @@ jobs:
mkdir -p databasus-data/backups
mkdir -p databasus-data/temp
- name: Cache PostgreSQL client tools
id: cache-postgres
uses: actions/cache@v4
with:
path: /usr/lib/postgresql
key: postgres-clients-12-18-v1
- name: Cache MySQL client tools
id: cache-mysql
uses: actions/cache@v4
with:
path: backend/tools/mysql
key: mysql-clients-57-80-84-9-v1
- name: Cache MariaDB client tools
id: cache-mariadb
uses: actions/cache@v4
with:
path: backend/tools/mariadb
key: mariadb-clients-106-121-v1
- name: Cache MongoDB Database Tools
id: cache-mongodb
uses: actions/cache@v4
with:
path: backend/tools/mongodb
key: mongodb-database-tools-100.10.0-v1
- name: Install MySQL dependencies
run: |
sudo apt-get update -qq
@@ -345,31 +304,58 @@ jobs:
sudo ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5
sudo ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5
- name: Install PostgreSQL, MySQL, MariaDB and MongoDB client tools
if: steps.cache-postgres.outputs.cache-hit != 'true' || steps.cache-mysql.outputs.cache-hit != 'true' || steps.cache-mariadb.outputs.cache-hit != 'true' || steps.cache-mongodb.outputs.cache-hit != 'true'
run: |
chmod +x backend/tools/download_linux.sh
cd backend/tools
./download_linux.sh
- name: Setup PostgreSQL symlinks (when using cache)
if: steps.cache-postgres.outputs.cache-hit == 'true'
- name: Setup PostgreSQL, MySQL and MariaDB client tools from pre-built assets
run: |
cd backend/tools
mkdir -p postgresql
# Create directory structure
mkdir -p postgresql mysql mariadb mongodb/bin
# Copy PostgreSQL client tools (12-18) from pre-built assets
for version in 12 13 14 15 16 17 18; do
version_dir="postgresql/postgresql-$version"
mkdir -p "$version_dir/bin"
pg_bin_dir="/usr/lib/postgresql/$version/bin"
if [ -d "$pg_bin_dir" ]; then
ln -sf "$pg_bin_dir/pg_dump" "$version_dir/bin/pg_dump"
ln -sf "$pg_bin_dir/pg_dumpall" "$version_dir/bin/pg_dumpall"
ln -sf "$pg_bin_dir/psql" "$version_dir/bin/psql"
ln -sf "$pg_bin_dir/pg_restore" "$version_dir/bin/pg_restore"
ln -sf "$pg_bin_dir/createdb" "$version_dir/bin/createdb"
ln -sf "$pg_bin_dir/dropdb" "$version_dir/bin/dropdb"
fi
mkdir -p postgresql/postgresql-$version
cp -r ../../assets/tools/x64/postgresql/postgresql-$version/bin postgresql/postgresql-$version/
done
# Copy MySQL client tools (5.7, 8.0, 8.4, 9) from pre-built assets
for version in 5.7 8.0 8.4 9; do
mkdir -p mysql/mysql-$version
cp -r ../../assets/tools/x64/mysql/mysql-$version/bin mysql/mysql-$version/
done
# Copy MariaDB client tools (10.6, 12.1) from pre-built assets
for version in 10.6 12.1; do
mkdir -p mariadb/mariadb-$version
cp -r ../../assets/tools/x64/mariadb/mariadb-$version/bin mariadb/mariadb-$version/
done
# Make all binaries executable
chmod +x postgresql/*/bin/*
chmod +x mysql/*/bin/*
chmod +x mariadb/*/bin/*
echo "Pre-built client tools setup complete"
- name: Install MongoDB Database Tools
run: |
cd backend/tools
# MongoDB Database Tools must be downloaded (not in pre-built assets)
# They are backward compatible - single version supports all servers (4.0-8.0)
MONGODB_TOOLS_URL="https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-x86_64-100.10.0.deb"
echo "Downloading MongoDB Database Tools..."
wget -q "$MONGODB_TOOLS_URL" -O /tmp/mongodb-database-tools.deb
echo "Installing MongoDB Database Tools..."
sudo dpkg -i /tmp/mongodb-database-tools.deb || sudo apt-get install -f -y --no-install-recommends
# Create symlinks to tools directory
ln -sf /usr/bin/mongodump mongodb/bin/mongodump
ln -sf /usr/bin/mongorestore mongodb/bin/mongorestore
rm -f /tmp/mongodb-database-tools.deb
echo "MongoDB Database Tools installed successfully"
- name: Verify MariaDB client tools exist
run: |
@@ -715,4 +701,4 @@ jobs:
- name: Push Helm chart to GHCR
run: |
VERSION="${{ needs.determine-version.outputs.new_version }}"
helm push databasus-${VERSION}.tgz oci://ghcr.io/databasus/charts
helm push databasus-${VERSION}.tgz oci://ghcr.io/databasus/charts

View File

@@ -27,3 +27,10 @@ repos:
language: system
files: ^backend/.*\.go$
pass_filenames: false
- id: backend-go-mod-tidy
name: Backend Go Mod Tidy
entry: bash -c "cd backend && go mod tidy"
language: system
files: ^backend/.*\.go$
pass_filenames: false

1345
AGENTS.md Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -22,7 +22,7 @@ RUN npm run build
# ========= BUILD BACKEND =========
# Backend build stage
FROM --platform=$BUILDPLATFORM golang:1.24.4 AS backend-build
FROM --platform=$BUILDPLATFORM golang:1.24.9 AS backend-build
# Make TARGET args available early so tools built here match the final image arch
ARG TARGETOS
@@ -123,6 +123,15 @@ RUN wget -qO- https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add -
apt-get install -y --no-install-recommends postgresql-17 && \
rm -rf /var/lib/apt/lists/*
# Install Valkey server from debian repository
# Valkey is only accessible internally (localhost) - not exposed outside container
RUN wget -O /usr/share/keyrings/greensec.github.io-valkey-debian.key https://greensec.github.io/valkey-debian/public.key && \
echo "deb [signed-by=/usr/share/keyrings/greensec.github.io-valkey-debian.key] https://greensec.github.io/valkey-debian/repo $(lsb_release -cs) main" \
> /etc/apt/sources.list.d/valkey-debian.list && \
apt-get update && \
apt-get install -y --no-install-recommends valkey && \
rm -rf /var/lib/apt/lists/*
# ========= Install rclone =========
RUN apt-get update && \
apt-get install -y --no-install-recommends rclone && \
@@ -250,6 +259,30 @@ mkdir -p /databasus-data/backups
chown -R postgres:postgres /databasus-data
chmod 700 /databasus-data/temp
# ========= Start Valkey (internal cache) =========
echo "Configuring Valkey cache..."
cat > /tmp/valkey.conf << 'VALKEY_CONFIG'
port 6379
bind 127.0.0.1
protected-mode yes
save ""
maxmemory 256mb
maxmemory-policy allkeys-lru
VALKEY_CONFIG
echo "Starting Valkey..."
valkey-server /tmp/valkey.conf &
VALKEY_PID=\$!
echo "Waiting for Valkey to be ready..."
for i in {1..30}; do
if valkey-cli ping >/dev/null 2>&1; then
echo "Valkey is ready!"
break
fi
sleep 1
done
# Initialize PostgreSQL if not already initialized
if [ ! -s "/databasus-data/pgdata/PG_VERSION" ]; then
echo "Initializing PostgreSQL database..."

View File

@@ -1,152 +0,0 @@
---
description:
globs:
alwaysApply: true
---
Always place private methods to the bottom of file
**This rule applies to ALL Go files including tests, services, controllers, repositories, etc.**
In Go, exported (public) functions/methods start with uppercase letters, while unexported (private) ones start with lowercase letters.
## Structure Order:
1. Type definitions and constants
2. Public methods/functions (uppercase)
3. Private methods/functions (lowercase)
## Examples:
### Service with methods:
```go
type UserService struct {
repository *UserRepository
}
// Public methods first
func (s *UserService) CreateUser(user *User) error {
if err := s.validateUser(user); err != nil {
return err
}
return s.repository.Save(user)
}
func (s *UserService) GetUser(id uuid.UUID) (*User, error) {
return s.repository.FindByID(id)
}
// Private methods at the bottom
func (s *UserService) validateUser(user *User) error {
if user.Name == "" {
return errors.New("name is required")
}
return nil
}
```
### Package-level functions:
```go
package utils
// Public functions first
func ProcessData(data []byte) (Result, error) {
cleaned := sanitizeInput(data)
return parseData(cleaned)
}
func ValidateInput(input string) bool {
return isValidFormat(input) && checkLength(input)
}
// Private functions at the bottom
func sanitizeInput(data []byte) []byte {
// implementation
}
func parseData(data []byte) (Result, error) {
// implementation
}
func isValidFormat(input string) bool {
// implementation
}
func checkLength(input string) bool {
// implementation
}
```
### Test files:
```go
package user_test
// Public test functions first
func Test_CreateUser_ValidInput_UserCreated(t *testing.T) {
user := createTestUser()
result, err := service.CreateUser(user)
assert.NoError(t, err)
assert.NotNil(t, result)
}
func Test_GetUser_ExistingUser_ReturnsUser(t *testing.T) {
user := createTestUser()
// test implementation
}
// Private helper functions at the bottom
func createTestUser() *User {
return &User{
Name: "Test User",
Email: "test@example.com",
}
}
func setupTestDatabase() *Database {
// setup implementation
}
```
### Controller example:
```go
type ProjectController struct {
service *ProjectService
}
// Public HTTP handlers first
func (c *ProjectController) CreateProject(ctx *gin.Context) {
var request CreateProjectRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
c.handleError(ctx, err)
return
}
// handler logic
}
func (c *ProjectController) GetProject(ctx *gin.Context) {
projectID := c.extractProjectID(ctx)
// handler logic
}
// Private helper methods at the bottom
func (c *ProjectController) handleError(ctx *gin.Context, err error) {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}
func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
return uuid.MustParse(ctx.Param("projectId"))
}
```
## Key Points:
- **Exported/Public** = starts with uppercase letter (CreateUser, GetProject)
- **Unexported/Private** = starts with lowercase letter (validateUser, handleError)
- This improves code readability by showing the public API first
- Private helpers are implementation details, so they go at the bottom
- Apply this rule consistently across ALL Go files in the project

View File

@@ -1,45 +0,0 @@
---
description:
globs:
alwaysApply: true
---
## Comment Guidelines
1. **No obvious comments** - Don't state what the code already clearly shows
2. **Functions and variables should have meaningful names** - Code should be self-documenting
3. **Comments for unclear code only** - Only add comments when code logic isn't immediately clear
## Key Principles:
- **Code should tell a story** - Use descriptive variable and function names
- **Comments explain WHY, not WHAT** - The code shows what happens, comments explain business logic or complex decisions
- **Prefer refactoring over commenting** - If code needs explaining, consider making it clearer instead
- **API documentation is required** - Swagger comments for all HTTP endpoints are mandatory
- **Complex algorithms deserve comments** - Mathematical formulas, business rules, or non-obvious optimizations
Example of useless comment:
1.
```sql
// Create projects table
CREATE TABLE projects (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
name TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
```
2.
```go
// Create test project
project := CreateTestProject(projectName, user, router)
```
3.
```go
// CreateValidLogItems creates valid log items for testing
func CreateValidLogItems(count int, uniqueID string) []logs_receiving.LogItemRequestDTO {
```

View File

@@ -1,133 +0,0 @@
---
description:
globs:
alwaysApply: true
---
1. When we write controller:
- we combine all routes to single controller
- names them as .WhatWeDo (not "handlers") concept
2. We use gin and \*gin.Context for all routes.
Example:
func (c *TasksController) GetAvailableTasks(ctx *gin.Context) ...
3. We document all routes with Swagger in the following format:
package audit_logs
import (
"net/http"
user_models "databasus-backend/internal/features/users/models"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type AuditLogController struct {
auditLogService \*AuditLogService
}
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
// All audit log endpoints require authentication (handled in main.go)
auditRoutes := router.Group("/audit-logs")
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
}
// GetGlobalAuditLogs
// @Summary Get global audit logs (ADMIN only)
// @Description Retrieve all audit logs across the system
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/global [get]
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(\*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
if err != nil {
if err.Error() == "only administrators can view global audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}
// GetUserAuditLogs
// @Summary Get user audit logs
// @Description Retrieve audit logs for a specific user
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param userId path string true "User ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/users/{userId} [get]
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(\*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
userIDStr := ctx.Param("userId")
targetUserID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
if err != nil {
if err.Error() == "insufficient permissions to view user audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}

View File

@@ -1,671 +0,0 @@
---
alwaysApply: false
---
This is example of CRUD:
------ backend/internal/features/audit_logs/controller.go ------
```
package audit_logs
import (
"net/http"
user_models "databasus-backend/internal/features/users/models"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type AuditLogController struct {
auditLogService *AuditLogService
}
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
// All audit log endpoints require authentication (handled in main.go)
auditRoutes := router.Group("/audit-logs")
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
}
// GetGlobalAuditLogs
// @Summary Get global audit logs (ADMIN only)
// @Description Retrieve all audit logs across the system
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/global [get]
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
if err != nil {
if err.Error() == "only administrators can view global audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}
// GetUserAuditLogs
// @Summary Get user audit logs
// @Description Retrieve audit logs for a specific user
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param userId path string true "User ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/users/{userId} [get]
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
userIDStr := ctx.Param("userId")
targetUserID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
if err != nil {
if err.Error() == "insufficient permissions to view user audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}
```
------ backend/internal/features/audit_logs/controller_test.go ------
```
package audit_logs
import (
"fmt"
"net/http"
"testing"
"time"
user_enums "databasus-backend/internal/features/users/enums"
users_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
users_testing "databasus-backend/internal/features/users/testing"
"databasus-backend/internal/storage"
test_utils "databasus-backend/internal/util/testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_GetGlobalAuditLogs_AdminSucceedsAndMemberGetsForbidden(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
memberUser := users_testing.CreateTestUser(user_enums.UserRoleMember)
router := createRouter()
service := GetAuditLogService()
projectID := uuid.New()
// Create test logs
createAuditLog(service, "Test log with user", &adminUser.UserID, nil)
createAuditLog(service, "Test log with project", nil, &projectID)
createAuditLog(service, "Test log standalone", nil, nil)
// Test ADMIN can access global logs
var response GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
"/api/v1/audit-logs/global?limit=10", "Bearer "+adminUser.Token, http.StatusOK, &response)
assert.GreaterOrEqual(t, len(response.AuditLogs), 3)
assert.GreaterOrEqual(t, response.Total, int64(3))
messages := extractMessages(response.AuditLogs)
assert.Contains(t, messages, "Test log with user")
assert.Contains(t, messages, "Test log with project")
assert.Contains(t, messages, "Test log standalone")
// Test MEMBER cannot access global logs
resp := test_utils.MakeGetRequest(t, router, "/api/v1/audit-logs/global",
"Bearer "+memberUser.Token, http.StatusForbidden)
assert.Contains(t, string(resp.Body), "only administrators can view global audit logs")
}
func Test_GetUserAuditLogs_PermissionsEnforcedCorrectly(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
router := createRouter()
service := GetAuditLogService()
projectID := uuid.New()
// Create test logs for different users
createAuditLog(service, "Test log user1 first", &user1.UserID, nil)
createAuditLog(service, "Test log user1 second", &user1.UserID, &projectID)
createAuditLog(service, "Test log user2 first", &user2.UserID, nil)
createAuditLog(service, "Test log user2 second", &user2.UserID, &projectID)
createAuditLog(service, "Test project log", nil, &projectID)
// Test ADMIN can view any user's logs
var user1Response GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=10", user1.UserID.String()),
"Bearer "+adminUser.Token, http.StatusOK, &user1Response)
assert.Equal(t, 2, len(user1Response.AuditLogs))
messages := extractMessages(user1Response.AuditLogs)
assert.Contains(t, messages, "Test log user1 first")
assert.Contains(t, messages, "Test log user1 second")
// Test user can view own logs
var ownLogsResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s", user2.UserID.String()),
"Bearer "+user2.Token, http.StatusOK, &ownLogsResponse)
assert.Equal(t, 2, len(ownLogsResponse.AuditLogs))
// Test user cannot view other user's logs
resp := test_utils.MakeGetRequest(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s", user1.UserID.String()),
"Bearer "+user2.Token, http.StatusForbidden)
assert.Contains(t, string(resp.Body), "insufficient permissions")
}
func Test_FilterAuditLogsByTime_ReturnsOnlyLogsBeforeDate(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
router := createRouter()
service := GetAuditLogService()
db := storage.GetDb()
baseTime := time.Now().UTC()
// Create logs with different timestamps
createTimedLog(db, &adminUser.UserID, "Test old log", baseTime.Add(-2*time.Hour))
createTimedLog(db, &adminUser.UserID, "Test recent log", baseTime.Add(-30*time.Minute))
createAuditLog(service, "Test current log", &adminUser.UserID, nil)
// Test filtering - get logs before 1 hour ago
beforeTime := baseTime.Add(-1 * time.Hour)
var filteredResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/global?beforeDate=%s", beforeTime.Format(time.RFC3339)),
"Bearer "+adminUser.Token, http.StatusOK, &filteredResponse)
// Verify only old log is returned
messages := extractMessages(filteredResponse.AuditLogs)
assert.Contains(t, messages, "Test old log")
assert.NotContains(t, messages, "Test recent log")
assert.NotContains(t, messages, "Test current log")
// Test without filter - should get all logs
var allResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router, "/api/v1/audit-logs/global",
"Bearer "+adminUser.Token, http.StatusOK, &allResponse)
assert.GreaterOrEqual(t, len(allResponse.AuditLogs), 3)
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
SetupDependencies()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
GetAuditLogController().RegisterRoutes(protected.(*gin.RouterGroup))
return router
}
```
------ backend/internal/features/audit_logs/di.go ------
```
package audit_logs
import (
users_services "databasus-backend/internal/features/users/services"
"databasus-backend/internal/util/logger"
)
var auditLogRepository = &AuditLogRepository{}
var auditLogService = &AuditLogService{
auditLogRepository: auditLogRepository,
logger: logger.GetLogger(),
}
var auditLogController = &AuditLogController{
auditLogService: auditLogService,
}
func GetAuditLogService() *AuditLogService {
return auditLogService
}
func GetAuditLogController() *AuditLogController {
return auditLogController
}
func SetupDependencies() {
users_services.GetUserService().SetAuditLogWriter(auditLogService)
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
}
```
------ backend/internal/features/audit_logs/dto.go ------
```
package audit_logs
import "time"
type GetAuditLogsRequest struct {
Limit int `form:"limit" json:"limit"`
Offset int `form:"offset" json:"offset"`
BeforeDate *time.Time `form:"beforeDate" json:"beforeDate"`
}
type GetAuditLogsResponse struct {
AuditLogs []*AuditLog `json:"auditLogs"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
```
------ backend/internal/features/audit_logs/models.go ------
```
package audit_logs
import (
"time"
"github.com/google/uuid"
)
type AuditLog struct {
ID uuid.UUID `json:"id" gorm:"column:id"`
UserID *uuid.UUID `json:"userId" gorm:"column:user_id"`
ProjectID *uuid.UUID `json:"projectId" gorm:"column:project_id"`
Message string `json:"message" gorm:"column:message"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
}
func (AuditLog) TableName() string {
return "audit_logs"
}
```
------ backend/internal/features/audit_logs/repository.go ------
```
package audit_logs
import (
"databasus-backend/internal/storage"
"time"
"github.com/google/uuid"
)
type AuditLogRepository struct{}
func (r *AuditLogRepository) Create(auditLog *AuditLog) error {
if auditLog.ID == uuid.Nil {
auditLog.ID = uuid.New()
}
return storage.GetDb().Create(auditLog).Error
}
func (r *AuditLogRepository) GetGlobal(limit, offset int, beforeDate *time.Time) ([]*AuditLog, error) {
var auditLogs []*AuditLog
query := storage.GetDb().Order("created_at DESC")
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.
Limit(limit).
Offset(offset).
Find(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) GetByUser(
userID uuid.UUID,
limit, offset int,
beforeDate *time.Time,
) ([]*AuditLog, error) {
var auditLogs []*AuditLog
query := storage.GetDb().
Where("user_id = ?", userID).
Order("created_at DESC")
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.
Limit(limit).
Offset(offset).
Find(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) GetByProject(
projectID uuid.UUID,
limit, offset int,
beforeDate *time.Time,
) ([]*AuditLog, error) {
var auditLogs []*AuditLog
query := storage.GetDb().
Where("project_id = ?", projectID).
Order("created_at DESC")
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.
Limit(limit).
Offset(offset).
Find(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
var count int64
query := storage.GetDb().Model(&AuditLog{})
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.Count(&count).Error
return count, err
}
```
------ backend/internal/features/audit_logs/service.go ------
```
package audit_logs
import (
"errors"
"log/slog"
"time"
user_enums "databasus-backend/internal/features/users/enums"
user_models "databasus-backend/internal/features/users/models"
"github.com/google/uuid"
)
type AuditLogService struct {
auditLogRepository *AuditLogRepository
logger *slog.Logger
}
func (s *AuditLogService) WriteAuditLog(
message string,
userID *uuid.UUID,
projectID *uuid.UUID,
) {
auditLog := &AuditLog{
UserID: userID,
ProjectID: projectID,
Message: message,
CreatedAt: time.Now().UTC(),
}
err := s.auditLogRepository.Create(auditLog)
if err != nil {
s.logger.Error("failed to create audit log", "error", err)
return
}
}
func (s *AuditLogService) CreateAuditLog(auditLog *AuditLog) error {
return s.auditLogRepository.Create(auditLog)
}
func (s *AuditLogService) GetGlobalAuditLogs(
user *user_models.User,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
if user.Role != user_enums.UserRoleAdmin {
return nil, errors.New("only administrators can view global audit logs")
}
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetGlobal(limit, offset, request.BeforeDate)
if err != nil {
return nil, err
}
total, err := s.auditLogRepository.CountGlobal(request.BeforeDate)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: total,
Limit: limit,
Offset: offset,
}, nil
}
func (s *AuditLogService) GetUserAuditLogs(
targetUserID uuid.UUID,
user *user_models.User,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
// Users can view their own logs, ADMIN can view any user's logs
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
return nil, errors.New("insufficient permissions to view user audit logs")
}
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetByUser(targetUserID, limit, offset, request.BeforeDate)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: int64(len(auditLogs)),
Limit: limit,
Offset: offset,
}, nil
}
func (s *AuditLogService) GetProjectAuditLogs(
projectID uuid.UUID,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetByProject(projectID, limit, offset, request.BeforeDate)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: int64(len(auditLogs)),
Limit: limit,
Offset: offset,
}, nil
}
```
------ backend/internal/features/audit_logs/service_test.go ------
```
package audit_logs
import (
"testing"
"time"
user_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
)
func Test_AuditLogs_ProjectSpecificLogs(t *testing.T) {
service := GetAuditLogService()
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
project1ID, project2ID := uuid.New(), uuid.New()
// Create test logs for projects
createAuditLog(service, "Test project1 log first", &user1.UserID, &project1ID)
createAuditLog(service, "Test project1 log second", &user2.UserID, &project1ID)
createAuditLog(service, "Test project2 log first", &user1.UserID, &project2ID)
createAuditLog(service, "Test project2 log second", &user2.UserID, &project2ID)
createAuditLog(service, "Test no project log", &user1.UserID, nil)
request := &GetAuditLogsRequest{Limit: 10, Offset: 0}
// Test project 1 logs
project1Response, err := service.GetProjectAuditLogs(project1ID, request)
assert.NoError(t, err)
assert.Equal(t, 2, len(project1Response.AuditLogs))
messages := extractMessages(project1Response.AuditLogs)
assert.Contains(t, messages, "Test project1 log first")
assert.Contains(t, messages, "Test project1 log second")
for _, log := range project1Response.AuditLogs {
assert.Equal(t, &project1ID, log.ProjectID)
}
// Test project 2 logs
project2Response, err := service.GetProjectAuditLogs(project2ID, request)
assert.NoError(t, err)
assert.Equal(t, 2, len(project2Response.AuditLogs))
messages2 := extractMessages(project2Response.AuditLogs)
assert.Contains(t, messages2, "Test project2 log first")
assert.Contains(t, messages2, "Test project2 log second")
// Test pagination
limitedResponse, err := service.GetProjectAuditLogs(project1ID,
&GetAuditLogsRequest{Limit: 1, Offset: 0})
assert.NoError(t, err)
assert.Equal(t, 1, len(limitedResponse.AuditLogs))
assert.Equal(t, 1, limitedResponse.Limit)
// Test beforeDate filter
beforeTime := time.Now().UTC().Add(-1 * time.Minute)
filteredResponse, err := service.GetProjectAuditLogs(project1ID,
&GetAuditLogsRequest{Limit: 10, BeforeDate: &beforeTime})
assert.NoError(t, err)
for _, log := range filteredResponse.AuditLogs {
assert.True(t, log.CreatedAt.Before(beforeTime))
}
}
func createAuditLog(service *AuditLogService, message string, userID, projectID *uuid.UUID) {
service.WriteAuditLog(message, userID, projectID)
}
func extractMessages(logs []*AuditLog) []string {
messages := make([]string, len(logs))
for i, log := range logs {
messages[i] = log.Message
}
return messages
}
func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt time.Time) {
log := &AuditLog{
ID: uuid.New(),
UserID: userID,
Message: message,
CreatedAt: createdAt,
}
db.Create(log)
}
```

View File

@@ -1,74 +0,0 @@
---
description:
globs:
alwaysApply: true
---
For DI files use implicit fields declaration styles (espesially
for controllers, services, repositories, use cases, etc., not simple
data structures).
So, instead of:
var orderController = &OrderController{
orderService: orderService,
botUserService: bot_users.GetBotUserService(),
botService: bots.GetBotService(),
userService: users.GetUserService(),
}
Use:
var orderController = &OrderController{
orderService,
bot_users.GetBotUserService(),
bots.GetBotService(),
users.GetUserService(),
}
This is needed to avoid forgetting to update DI style
when we add new dependency.
---
Please force such usage if file look like this (see some
services\controllers\repos definitions and getters):
var orderBackgroundService = &OrderBackgroundService{
orderService: orderService,
orderPaymentRepository: orderPaymentRepository,
botService: bots.GetBotService(),
paymentSettingsService: payment_settings.GetPaymentSettingsService(),
orderSubscriptionListeners: []OrderSubscriptionListener{},
}
var orderController = &OrderController{
orderService: orderService,
botUserService: bot_users.GetBotUserService(),
botService: bots.GetBotService(),
userService: users.GetUserService(),
}
func GetUniquePaymentRepository() *repositories.UniquePaymentRepository {
return uniquePaymentRepository
}
func GetOrderPaymentRepository() *repositories.OrderPaymentRepository {
return orderPaymentRepository
}
func GetOrderService() *OrderService {
return orderService
}
func GetOrderController() *OrderController {
return orderController
}
func GetOrderBackgroundService() *OrderBackgroundService {
return orderBackgroundService
}
func GetOrderRepository() *repositories.OrderRepository {
return orderRepository
}

View File

@@ -1,27 +0,0 @@
---
description:
globs:
alwaysApply: true
---
When writting migrations:
- write them for PostgreSQL
- for PRIMARY UUID keys use gen_random_uuid()
- for time use TIMESTAMPTZ (timestamp with zone)
- split table, constraint and indexes declaration (table first, them other one by one)
- format SQL in pretty way (add spaces, align columns types), constraints split by lines. The example:
CREATE TABLE marketplace_info (
bot_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
title TEXT NOT NULL,
description TEXT NOT NULL,
short_description TEXT NOT NULL,
tutorial_url TEXT,
info_order BIGINT NOT NULL DEFAULT 0,
is_published BOOLEAN NOT NULL DEFAULT FALSE
);
ALTER TABLE marketplace_info_images
ADD CONSTRAINT fk_marketplace_info_images_bot_id
FOREIGN KEY (bot_id)
REFERENCES marketplace_info (bot_id);

View File

@@ -1,12 +0,0 @@
---
description:
globs:
alwaysApply: true
---
When applying changes, do not forget to refactor old code.
You can shortify, make more readable, improve code quality, etc.
Common logic can be extracted to functions, constants, files, etc.
After each large change with more than ~50-100 lines of code - always run `make lint` (from backend root folder) and, if you change frontend, run `npm run format` (from frontend root folder).

View File

@@ -1,147 +0,0 @@
---
description:
globs:
alwaysApply: true
---
After writing tests, always launch them and verify that they pass.
## Test Naming Format
Use these naming patterns:
- `Test_WhatWeDo_WhatWeExpect`
- `Test_WhatWeDo_WhichConditions_WhatWeExpect`
## Examples from Real Codebase:
- `Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated`
- `Test_UpdateProject_WhenUserIsProjectAdmin_ProjectUpdated`
- `Test_DeleteApiKey_WhenUserIsProjectMember_ReturnsForbidden`
- `Test_GetProjectAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly`
- `Test_ProjectLifecycleE2E_CompletesSuccessfully`
## Testing Philosophy
**Prefer Controllers Over Unit Tests:**
- Test through HTTP endpoints via controllers whenever possible
- Avoid testing repositories, services in isolation - test via API instead
- Only use unit tests for complex model logic when no API exists
- Name test files `controller_test.go` or `service_test.go`, not `integration_test.go`
**Extract Common Logic to Testing Utilities:**
- Create `testing.go` or `testing/testing.go` files for shared test utilities
- Extract router creation, user setup, models creation helpers (in API, not just structs creation)
- Reuse common patterns across different test files
**Refactor Existing Tests:**
- When working with existing tests, always look for opportunities to refactor and improve
- Extract repetitive setup code to common utilities
- Simplify complex tests by breaking them into smaller, focused tests
- Replace inline test data creation with reusable helper functions
- Consolidate similar test patterns across different test files
- Make tests more readable and maintainable for other developers
## Testing Utilities Structure
**Create `testing.go` or `testing/testing.go` files with common utilities:**
```go
package projects_testing
// CreateTestRouter creates unified router for all controllers
func CreateTestRouter(controllers ...ControllerInterface) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
for _, controller := range controllers {
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
controller.RegisterRoutes(routerGroup)
}
}
return router
}
// CreateTestProjectViaAPI creates project through HTTP API
func CreateTestProjectViaAPI(name string, owner *users_dto.SignInResponseDTO, router *gin.Engine) (*projects_models.Project, string) {
request := projects_dto.CreateProjectRequestDTO{Name: name}
w := MakeAPIRequest(router, "POST", "/api/v1/projects", "Bearer "+owner.Token, request)
// Handle response...
return project, owner.Token
}
// AddMemberToProject adds member via API call
func AddMemberToProject(project *projects_models.Project, member *users_dto.SignInResponseDTO, role users_enums.ProjectRole, ownerToken string, router *gin.Engine) {
// Implementation...
}
```
## Controller Test Examples
**Permission-based testing:**
```go
func Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated(t *testing.T) {
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
project, _ := projects_testing.CreateTestProjectViaAPI("Test Project", owner, router)
request := CreateApiKeyRequestDTO{Name: "Test API Key"}
var response ApiKey
test_utils.MakePostRequestAndUnmarshal(t, router, "/api/v1/projects/api-keys/"+project.ID.String(), "Bearer "+owner.Token, request, http.StatusOK, &response)
assert.Equal(t, "Test API Key", response.Name)
assert.NotEmpty(t, response.Token)
}
```
**Cross-project security testing:**
```go
func Test_UpdateApiKey_WithApiKeyFromDifferentProject_ReturnsBadRequest(t *testing.T) {
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
project1, _ := projects_testing.CreateTestProjectViaAPI("Project 1", owner1, router)
project2, _ := projects_testing.CreateTestProjectViaAPI("Project 2", owner2, router)
apiKey := CreateTestApiKey("Cross Project Key", project1.ID, owner1.Token, router)
// Try to update via different project endpoint
request := UpdateApiKeyRequestDTO{Name: &"Hacked Key"}
resp := test_utils.MakePutRequest(t, router, "/api/v1/projects/api-keys/"+project2.ID.String()+"/"+apiKey.ID.String(), "Bearer "+owner2.Token, request, http.StatusBadRequest)
assert.Contains(t, string(resp.Body), "API key does not belong to this project")
}
```
**E2E lifecycle testing:**
```go
func Test_ProjectLifecycleE2E_CompletesSuccessfully(t *testing.T) {
router := projects_testing.CreateTestRouter(GetProjectController(), GetMembershipController())
// 1. Create project
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
project := projects_testing.CreateTestProject("E2E Project", owner, router)
// 2. Add member
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
projects_testing.AddMemberToProject(project, member, users_enums.ProjectRoleMember, owner.Token, router)
// 3. Promote to admin
projects_testing.ChangeMemberRole(project, member.UserID, users_enums.ProjectRoleAdmin, owner.Token, router)
// 4. Transfer ownership
projects_testing.TransferProjectOwnership(project, member.UserID, owner.Token, router)
// 5. Verify new owner can manage project
finalProject := projects_testing.GetProject(project.ID, member.Token, router)
assert.Equal(t, project.ID, finalProject.ID)
}
```

View File

@@ -1,6 +0,0 @@
---
description:
globs:
alwaysApply: true
---
Always use time.Now().UTC() instead of time.Now()

View File

@@ -2,8 +2,10 @@
DEV_DB_NAME=databasus
DEV_DB_USERNAME=postgres
DEV_DB_PASSWORD=Q1234567
#app
# app
ENV_MODE=development
# logging
SHOW_DB_INSTALLATION_VERIFICATION_LOGS=true
# db
DATABASE_DSN=host=dev-db user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
@@ -11,6 +13,12 @@ DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
GOOSE_DRIVER=postgres
GOOSE_DBSTRING=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
GOOSE_MIGRATION_DIR=./migrations
# valkey
VALKEY_HOST=127.0.0.1
VALKEY_PORT=6379
VALKEY_USERNAME=
VALKEY_PASSWORD=
VALKEY_IS_SSL=false
# testing
# to get Google Drive env variables: add storage in UI and copy data from added storage here
TEST_GOOGLE_DRIVE_CLIENT_ID=

View File

@@ -10,4 +10,10 @@ DATABASE_URL=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disab
# migrations
GOOSE_DRIVER=postgres
GOOSE_DBSTRING=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
GOOSE_MIGRATION_DIR=./migrations
GOOSE_MIGRATION_DIR=./migrations
# valkey
VALKEY_HOST=127.0.0.1
VALKEY_PORT=6379
VALKEY_USERNAME=
VALKEY_PASSWORD=
VALKEY_IS_SSL=false

3
backend/.gitignore vendored
View File

@@ -17,4 +17,5 @@ ui/build/*
pgdata-for-restore/
temp/
cmd.exe
temp/
temp/
valkey-data/

View File

@@ -2,10 +2,10 @@ run:
go run cmd/main.go
test:
go test -p=1 -count=1 -failfast -timeout 10m ./internal/...
go test -p=1 -count=1 -failfast -timeout 15m ./internal/...
lint:
golangci-lint fmt && golangci-lint run
golangci-lint fmt ./cmd/... ./internal/... && golangci-lint run ./cmd/... ./internal/...
migration-create:
goose create $(name) sql

View File

@@ -15,6 +15,8 @@ import (
"databasus-backend/internal/config"
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
@@ -25,10 +27,13 @@ import (
"databasus-backend/internal/features/restores"
"databasus-backend/internal/features/storages"
system_healthcheck "databasus-backend/internal/features/system/healthcheck"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
task_registry "databasus-backend/internal/features/tasks/registry"
users_controllers "databasus-backend/internal/features/users/controllers"
users_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
cache_utils "databasus-backend/internal/util/cache"
env_utils "databasus-backend/internal/util/env"
files_utils "databasus-backend/internal/util/files"
"databasus-backend/internal/util/logger"
@@ -52,7 +57,23 @@ import (
func main() {
log := logger.GetLogger()
runMigrations(log)
cache_utils.TestCacheConnection()
if config.GetEnv().IsPrimaryNode {
log.Info("Clearing cache...")
err := cache_utils.ClearAllCache()
if err != nil {
log.Error("Failed to clear cache", "error", err)
os.Exit(1)
}
}
if config.GetEnv().IsPrimaryNode {
runMigrations(log)
} else {
log.Info("Skipping migrations (IS_PRIMARY_NODE is false)")
}
// create directories that used for backups and restore
err := files_utils.EnsureDirectories([]string{
@@ -96,7 +117,9 @@ func main() {
enableCors(ginApp)
setUpRoutes(ginApp)
setUpDependencies()
runBackgroundTasks(log)
mountFrontend(ginApp)
startServerWithGracefulShutdown(log, ginApp)
@@ -219,35 +242,68 @@ func setUpDependencies() {
notifiers.SetupDependencies()
storages.SetupDependencies()
backups_config.SetupDependencies()
task_cancellation.SetupDependencies()
}
func runBackgroundTasks(log *slog.Logger) {
log.Info("Preparing to run background tasks...")
// Create context that will be cancelled on shutdown
ctx, cancel := context.WithCancel(context.Background())
// Set up signal handling for graceful shutdown
quit := make(chan os.Signal, 1)
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
go func() {
<-quit
log.Info("Shutdown signal received, cancelling all background tasks")
cancel()
}()
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
if err != nil {
log.Error("Failed to clean temp folder", "error", err)
}
go runWithPanicLogging(log, "backup background service", func() {
backups.GetBackupBackgroundService().Run()
})
if config.GetEnv().IsPrimaryNode {
log.Info("Starting primary node background tasks...")
go runWithPanicLogging(log, "restore background service", func() {
restores.GetRestoreBackgroundService().Run()
})
go runWithPanicLogging(log, "backup background service", func() {
backuping.GetBackupsScheduler().Run(ctx)
})
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run()
})
go runWithPanicLogging(log, "restore background service", func() {
restores.GetRestoreBackgroundService().Run(ctx)
})
go runWithPanicLogging(log, "audit log cleanup background service", func() {
audit_logs.GetAuditLogBackgroundService().Run()
})
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run(ctx)
})
go runWithPanicLogging(log, "download token cleanup background service", func() {
backups.GetDownloadTokenBackgroundService().Run()
})
go runWithPanicLogging(log, "audit log cleanup background service", func() {
audit_logs.GetAuditLogBackgroundService().Run(ctx)
})
go runWithPanicLogging(log, "download token cleanup background service", func() {
backups_download.GetDownloadTokenBackgroundService().Run(ctx)
})
go runWithPanicLogging(log, "task nodes registry background service", func() {
task_registry.GetTaskNodesRegistry().Run(ctx)
})
} else {
log.Info("Skipping primary node tasks as not primary node")
}
if config.GetEnv().IsBackupNode {
log.Info("Starting backup node background tasks...")
go runWithPanicLogging(log, "backup node", func() {
backuping.GetBackuperNode().Run(ctx)
})
} else {
log.Info("Skipping backup node tasks as not backup node")
}
}
func runWithPanicLogging(log *slog.Logger, serviceName string, fn func()) {
@@ -290,16 +346,13 @@ func generateSwaggerDocs(log *slog.Logger) {
func runMigrations(log *slog.Logger) {
log.Info("Running database migrations...")
cmd := exec.Command("goose", "up")
cmd := exec.Command("goose", "-dir", "./migrations", "up")
cmd.Env = append(
os.Environ(),
"GOOSE_DRIVER=postgres",
"GOOSE_DBSTRING="+config.GetEnv().DatabaseDsn,
)
// Set the working directory to where migrations are located
cmd.Dir = "./migrations"
output, err := cmd.CombinedOutput()
if err != nil {
log.Error("Failed to run migrations", "error", err, "output", string(output))

View File

@@ -19,6 +19,21 @@ services:
command: -p 5437
shm_size: 10gb
# Valkey for caching
dev-valkey:
image: valkey/valkey:9.0.1-alpine
ports:
- "${VALKEY_PORT:-6379}:6379"
volumes:
- ./valkey-data:/data
container_name: dev-valkey
healthcheck:
test: ["CMD", "valkey-cli", "ping"]
interval: 10s
timeout: 5s
retries: 5
start_period: 20s
# Test MinIO container
test-minio:
image: minio/minio:latest

View File

@@ -1,6 +1,6 @@
module databasus-backend
go 1.24.4
go 1.24.9
require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
@@ -25,9 +25,9 @@ require (
github.com/swaggo/files v1.0.1
github.com/swaggo/gin-swagger v1.6.0
github.com/swaggo/swag v1.16.4
github.com/valkey-io/valkey-go v1.0.70
go.mongodb.org/mongo-driver v1.17.6
golang.org/x/crypto v0.46.0
golang.org/x/time v0.14.0
gorm.io/driver/postgres v1.5.11
gorm.io/gorm v1.26.1
)
@@ -185,6 +185,7 @@ require (
go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/term v0.38.0 // indirect
golang.org/x/time v0.14.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/validator.v2 v2.0.1 // indirect
moul.io/http2curl/v2 v2.3.0 // indirect
@@ -269,7 +270,7 @@ require (
go.opentelemetry.io/otel/metric v1.38.0 // indirect
go.opentelemetry.io/otel/trace v1.38.0 // indirect
golang.org/x/arch v0.17.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/net v0.48.0 // indirect
golang.org/x/oauth2 v0.33.0
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.39.0 // indirect

View File

@@ -539,8 +539,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/ginkgo/v2 v2.17.3 h1:oJcvKpIb7/8uLpDDtnQuf18xVnwKp8DTD7DQ6gTd/MU=
github.com/onsi/ginkgo/v2 v2.17.3/go.mod h1:nP2DPOQoNsQmsVyv5rDA8JkXQoCs6goXIvr/PRJ1eCc=
github.com/onsi/gomega v1.37.0 h1:CdEG8g0S133B4OswTDC/5XPSzE1OeP29QOioj2PID2Y=
github.com/onsi/gomega v1.37.0/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0=
github.com/onsi/gomega v1.38.3 h1:eTX+W6dobAYfFeGC2PV6RwXRu/MyT+cQguijutvkpSM=
github.com/onsi/gomega v1.38.3/go.mod h1:ZCU1pkQcXDO5Sl9/VVEGlDyp+zm0m1cmeG5TOzLgdh4=
github.com/oracle/oci-go-sdk/v65 v65.104.0 h1:l9awEvzWvxmYhy/97A0hZ87pa7BncYXmcO/S8+rvgK0=
github.com/oracle/oci-go-sdk/v65 v65.104.0/go.mod h1:oB8jFGVc/7/zJ+DbleE8MzGHjhs2ioCz5stRTdZdIcY=
github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg=
@@ -660,6 +660,8 @@ github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
github.com/unknwon/goconfig v1.0.0 h1:rS7O+CmUdli1T+oDm7fYj1MwqNWtEJfNj+FqcUHML8U=
github.com/unknwon/goconfig v1.0.0/go.mod h1:qu2ZQ/wcC/if2u32263HTVC39PeOQRSmidQk3DuDFQ8=
github.com/valkey-io/valkey-go v1.0.70 h1:mjYNT8qiazxDAJ0QNQ8twWT/YFOkOoRd40ERV2mB49Y=
github.com/valkey-io/valkey-go v1.0.70/go.mod h1:VGhZ6fs68Qrn2+OhH+6waZH27bjpgQOiLyUQyXuYK5k=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM=
@@ -720,6 +722,8 @@ go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/arch v0.17.0 h1:4O3dfLzd+lQewptAHqjewQZQDyEdejz3VwgeYwkZneU=
golang.org/x/arch v0.17.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
@@ -818,8 +822,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=

View File

@@ -9,6 +9,7 @@ import (
"strings"
"sync"
"github.com/google/uuid"
"github.com/ilyakaznacheev/cleanenv"
"github.com/joho/godotenv"
)
@@ -29,6 +30,14 @@ type EnvVariables struct {
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`
ShowDbInstallationVerificationLogs bool `env:"SHOW_DB_INSTALLATION_VERIFICATION_LOGS"`
NodeID string
IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
IsBackupNode bool `env:"IS_BACKUP_NODE"`
NodeNetworkThroughputMBs int `env:"NODE_NETWORK_THROUGHPUT_MBPS"`
DataFolder string
TempFolder string
SecretKeyPath string
@@ -79,6 +88,13 @@ type EnvVariables struct {
TestMongodb70Port string `env:"TEST_MONGODB_70_PORT"`
TestMongodb82Port string `env:"TEST_MONGODB_82_PORT"`
// Valkey
ValkeyHost string `env:"VALKEY_HOST" required:"true"`
ValkeyPort string `env:"VALKEY_PORT" required:"true"`
ValkeyUsername string `env:"VALKEY_USERNAME"`
ValkeyPassword string `env:"VALKEY_PASSWORD"`
ValkeyIsSsl bool `env:"VALKEY_IS_SSL" required:"true"`
// oauth
GitHubClientID string `env:"GITHUB_CLIENT_ID"`
GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET"`
@@ -155,6 +171,11 @@ func loadEnvVariables() {
os.Exit(1)
}
// Set default value for ShowDbInstallationVerificationLogs if not defined
if os.Getenv("SHOW_DB_INSTALLATION_VERIFICATION_LOGS") == "" {
env.ShowDbInstallationVerificationLogs = true
}
for _, arg := range os.Args {
if strings.Contains(arg, "test") {
env.IsTesting = true
@@ -178,16 +199,56 @@ func loadEnvVariables() {
log.Info("ENV_MODE loaded", "mode", env.EnvMode)
env.PostgresesInstallDir = filepath.Join(backendRoot, "tools", "postgresql")
tools.VerifyPostgresesInstallation(log, env.EnvMode, env.PostgresesInstallDir)
tools.VerifyPostgresesInstallation(
log,
env.EnvMode,
env.PostgresesInstallDir,
env.ShowDbInstallationVerificationLogs,
)
env.MysqlInstallDir = filepath.Join(backendRoot, "tools", "mysql")
tools.VerifyMysqlInstallation(log, env.EnvMode, env.MysqlInstallDir)
tools.VerifyMysqlInstallation(
log,
env.EnvMode,
env.MysqlInstallDir,
env.ShowDbInstallationVerificationLogs,
)
env.MariadbInstallDir = filepath.Join(backendRoot, "tools", "mariadb")
tools.VerifyMariadbInstallation(log, env.EnvMode, env.MariadbInstallDir)
tools.VerifyMariadbInstallation(
log,
env.EnvMode,
env.MariadbInstallDir,
env.ShowDbInstallationVerificationLogs,
)
env.MongodbInstallDir = filepath.Join(backendRoot, "tools", "mongodb")
tools.VerifyMongodbInstallation(log, env.EnvMode, env.MongodbInstallDir)
tools.VerifyMongodbInstallation(
log,
env.EnvMode,
env.MongodbInstallDir,
env.ShowDbInstallationVerificationLogs,
)
env.NodeID = uuid.New().String()
if env.NodeNetworkThroughputMBs == 0 {
env.NodeNetworkThroughputMBs = 125 // 1 Gbit/s
}
if !env.IsManyNodesMode {
env.IsPrimaryNode = true
env.IsBackupNode = true
}
// Valkey
if env.ValkeyHost == "" {
log.Error("VALKEY_HOST is empty")
os.Exit(1)
}
if env.ValkeyPort == "" {
log.Error("VALKEY_PORT is empty")
os.Exit(1)
}
// Store the data and temp folders one level below the root
// (projectRoot/databasus-data -> /databasus-data)

View File

@@ -1,7 +1,7 @@
package audit_logs
import (
"databasus-backend/internal/config"
"context"
"log/slog"
"time"
)
@@ -11,23 +11,25 @@ type AuditLogBackgroundService struct {
logger *slog.Logger
}
func (s *AuditLogBackgroundService) Run() {
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
s.logger.Info("Starting audit log cleanup background service")
if config.IsShouldShutdown() {
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
if config.IsShouldShutdown() {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
}
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
time.Sleep(1 * time.Hour)
}
}

View File

@@ -1,254 +0,0 @@
package backups
import (
"databasus-backend/internal/config"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
"log/slog"
"time"
)
type BackupBackgroundService struct {
backupService *BackupService
backupRepository *BackupRepository
backupConfigService *backups_config.BackupConfigService
storageService *storages.StorageService
lastBackupTime time.Time
logger *slog.Logger
}
func (s *BackupBackgroundService) Run() {
s.lastBackupTime = time.Now().UTC()
if err := s.failBackupsInProgress(); err != nil {
s.logger.Error("Failed to fail backups in progress", "error", err)
panic(err)
}
if config.IsShouldShutdown() {
return
}
for {
if config.IsShouldShutdown() {
return
}
if err := s.cleanOldBackups(); err != nil {
s.logger.Error("Failed to clean old backups", "error", err)
}
if err := s.runPendingBackups(); err != nil {
s.logger.Error("Failed to run pending backups", "error", err)
}
s.lastBackupTime = time.Now().UTC()
time.Sleep(1 * time.Minute)
}
}
func (s *BackupBackgroundService) IsBackupsWorkerRunning() bool {
// if last backup time is more than 5 minutes ago, return false
return s.lastBackupTime.After(time.Now().UTC().Add(-5 * time.Minute))
}
func (s *BackupBackgroundService) failBackupsInProgress() error {
backupsInProgress, err := s.backupRepository.FindByStatus(BackupStatusInProgress)
if err != nil {
return err
}
for _, backup := range backupsInProgress {
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
continue
}
failMessage := "Backup failed due to application restart"
backup.FailMessage = &failMessage
backup.Status = BackupStatusFailed
backup.BackupSizeMb = 0
s.backupService.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupFailed,
&failMessage,
)
if err := s.backupRepository.Save(backup); err != nil {
return err
}
}
return nil
}
func (s *BackupBackgroundService) cleanOldBackups() error {
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
backupStorePeriod := backupConfig.StorePeriod
if backupStorePeriod == period.PeriodForever {
continue
}
storeDuration := backupStorePeriod.ToDuration()
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
backupConfig.DatabaseID,
dateBeforeBackupsShouldBeDeleted,
)
if err != nil {
s.logger.Error(
"Failed to find old backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
for _, backup := range oldBackups {
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
s.logger.Error(
"Failed to get storage by ID",
"storageId",
backup.StorageID,
"error",
err,
)
continue
}
encryptor := encryption.GetFieldEncryptor()
err = storage.DeleteFile(encryptor, backup.ID)
if err != nil {
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
}
if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
continue
}
s.logger.Info(
"Deleted old backup",
"backupId",
backup.ID,
"databaseId",
backupConfig.DatabaseID,
)
}
}
return nil
}
func (s *BackupBackgroundService) runPendingBackups() error {
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
if backupConfig.BackupInterval == nil {
continue
}
lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
if err != nil {
s.logger.Error(
"Failed to get last backup for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
var lastBackupTime *time.Time
if lastBackup != nil {
lastBackupTime = &lastBackup.CreatedAt
}
remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
remainedBackupTryCount > 0 {
s.logger.Info(
"Triggering scheduled backup",
"databaseId",
backupConfig.DatabaseID,
"intervalType",
backupConfig.BackupInterval.Interval,
)
go s.backupService.MakeBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
s.logger.Info(
"Successfully triggered scheduled backup",
"databaseId",
backupConfig.DatabaseID,
)
}
}
return nil
}
// GetRemainedBackupTryCount returns the number of remaining backup tries for a given backup.
// If the backup is not failed or the backup config does not allow retries, it returns 0.
// If the backup is failed and the backup config allows retries, it returns the number of remaining tries.
// If the backup is failed and the backup config does not allow retries, it returns 0.
func (s *BackupBackgroundService) GetRemainedBackupTryCount(lastBackup *Backup) int {
if lastBackup == nil {
return 0
}
if lastBackup.Status != BackupStatusFailed {
return 0
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
return 0
}
if !backupConfig.IsRetryIfFailed {
return 0
}
maxFailedTriesCount := backupConfig.MaxFailedTriesCount
lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
lastBackup.DatabaseID,
maxFailedTriesCount,
)
if err != nil {
s.logger.Error("Failed to find last backups by database ID", "error", err)
return 0
}
lastFailedBackups := make([]*Backup, 0)
for _, backup := range lastBackups {
if backup.Status == BackupStatusFailed {
lastFailedBackups = append(lastFailedBackups, backup)
}
}
return maxFailedTriesCount - len(lastFailedBackups)
}

View File

@@ -1,389 +0,0 @@
package backups
import (
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/util/period"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// add old backup
backupRepository.Save(&Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
// Wait for backup to complete (runs in goroutine)
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
// assertions
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 2)
}
func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// add recent backup (1 hour ago)
backupRepository.Save(&Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
time.Sleep(100 * time.Millisecond)
// assertions
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1) // Should still be 1 backup, no new backup created
}
func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T) {
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database with retries disabled
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = false
backupConfig.MaxFailedTriesCount = 0
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// add failed backup
failMessage := "backup failed"
backupRepository.Save(&Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
FailMessage: &failMessage,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
time.Sleep(100 * time.Millisecond)
// assertions
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1) // Should still be 1 backup, no retry attempted
}
func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database with retries enabled
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = true
backupConfig.MaxFailedTriesCount = 3
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// add failed backup
failMessage := "backup failed"
backupRepository.Save(&Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
FailMessage: &failMessage,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
// Wait for backup to complete (runs in goroutine)
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
// assertions
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 2) // Should have 2 backups, retry was attempted
}
func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *testing.T) {
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database with retries enabled
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = true
backupConfig.MaxFailedTriesCount = 3
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
failMessage := "backup failed"
for i := 0; i < 3; i++ {
backupRepository.Save(&Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
FailMessage: &failMessage,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
}
GetBackupBackgroundService().runPendingBackups()
time.Sleep(100 * time.Millisecond)
// assertions
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 3) // Should have 3 backups, not more than max
}
func Test_MakeBackgroundBackupWhenBakupsDisabled_BackupSkipped(t *testing.T) {
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = false
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// add old backup that would trigger new backup if enabled
backupRepository.Save(&Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
time.Sleep(100 * time.Millisecond)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
}

View File

@@ -1,60 +0,0 @@
package backups
import (
"context"
"sync"
"github.com/google/uuid"
)
type BackupContextManager struct {
mu sync.RWMutex
cancelFuncs map[uuid.UUID]context.CancelFunc
cancelledBackups map[uuid.UUID]bool
}
func NewBackupContextManager() *BackupContextManager {
return &BackupContextManager{
cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
cancelledBackups: make(map[uuid.UUID]bool),
}
}
func (m *BackupContextManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
m.mu.Lock()
defer m.mu.Unlock()
m.cancelFuncs[backupID] = cancelFunc
delete(m.cancelledBackups, backupID)
}
func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.cancelledBackups[backupID] {
return nil
}
cancelFunc, exists := m.cancelFuncs[backupID]
if exists {
cancelFunc()
delete(m.cancelFuncs, backupID)
}
m.cancelledBackups[backupID] = true
return nil
}
func (m *BackupContextManager) IsCancelled(backupID uuid.UUID) bool {
m.mu.RLock()
defer m.mu.RUnlock()
return m.cancelledBackups[backupID]
}
func (m *BackupContextManager) UnregisterBackup(backupID uuid.UUID) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.cancelFuncs, backupID)
delete(m.cancelledBackups, backupID)
}

View File

@@ -0,0 +1,365 @@
package backuping
import (
"context"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
task_registry "databasus-backend/internal/features/tasks/registry"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
"errors"
"fmt"
"log/slog"
"slices"
"strings"
"time"
"github.com/google/uuid"
)
const (
heartbeatTickerInterval = 15 * time.Second
backuperHeathcheckThreshold = 5 * time.Minute
)
type BackuperNode struct {
databaseService *databases.DatabaseService
fieldEncryptor util_encryption.FieldEncryptor
workspaceService *workspaces_services.WorkspaceService
backupRepository *backups_core.BackupRepository
backupConfigService *backups_config.BackupConfigService
storageService *storages.StorageService
notificationSender backups_core.NotificationSender
backupCancelManager *tasks_cancellation.TaskCancelManager
tasksRegistry *task_registry.TaskNodesRegistry
logger *slog.Logger
createBackupUseCase backups_core.CreateBackupUsecase
nodeID uuid.UUID
lastHeartbeat time.Time
}
func (n *BackuperNode) Run(ctx context.Context) {
n.lastHeartbeat = time.Now().UTC()
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
backupNode := task_registry.TaskNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
}
if err := n.tasksRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
n.MakeBackup(backupID, isCallNotifier)
if err := n.tasksRegistry.PublishTaskCompletion(n.nodeID.String(), backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
err,
"backupID",
backupID,
)
}
}
if err := n.tasksRegistry.SubscribeNodeForTasksAssignment(n.nodeID.String(), backupHandler); err != nil {
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
panic(err)
}
defer func() {
if err := n.tasksRegistry.UnsubscribeNodeForTasksAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
}
}()
ticker := time.NewTicker(heartbeatTickerInterval)
defer ticker.Stop()
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.tasksRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
}
return
case <-ticker.C:
n.sendHeartbeat(&backupNode)
}
}
}
func (n *BackuperNode) IsBackuperRunning() bool {
return n.lastHeartbeat.After(time.Now().UTC().Add(-backuperHeathcheckThreshold))
}
func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
backup, err := n.backupRepository.FindByID(backupID)
if err != nil {
n.logger.Error("Failed to get backup by ID", "backupId", backupID, "error", err)
return
}
databaseID := backup.DatabaseID
database, err := n.databaseService.GetDatabaseByID(databaseID)
if err != nil {
n.logger.Error("Failed to get database by ID", "databaseId", databaseID, "error", err)
return
}
backupConfig, err := n.backupConfigService.GetBackupConfigByDbId(databaseID)
if err != nil {
n.logger.Error("Failed to get backup config by database ID", "error", err)
return
}
if backupConfig.StorageID == nil {
n.logger.Error("Backup config storage ID is not defined")
return
}
storage, err := n.storageService.GetStorageByID(*backupConfig.StorageID)
if err != nil {
n.logger.Error("Failed to get storage by ID", "error", err)
return
}
start := time.Now().UTC()
backupProgressListener := func(
completedMBs float64,
) {
backup.BackupSizeMb = completedMBs
backup.BackupDurationMs = time.Since(start).Milliseconds()
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to update backup progress", "error", err)
}
}
ctx, cancel := context.WithCancel(context.Background())
n.backupCancelManager.RegisterTask(backup.ID, cancel)
defer n.backupCancelManager.UnregisterTask(backup.ID)
backupMetadata, err := n.createBackupUseCase.Execute(
ctx,
backup.ID,
backupConfig,
database,
storage,
backupProgressListener,
)
if err != nil {
errMsg := err.Error()
// Log detailed error information for debugging
n.logger.Error("Backup execution failed",
"backupId", backup.ID,
"databaseId", databaseID,
"databaseType", database.Type,
"storageId", storage.ID,
"storageType", storage.Type,
"error", err,
"errorMessage", errMsg,
)
// Check if backup was cancelled (not due to shutdown)
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
strings.Contains(errMsg, "context canceled") ||
errors.Is(err, context.Canceled)
isShutdown := strings.Contains(errMsg, "shutdown")
if isCancelled && !isShutdown {
n.logger.Warn("Backup was cancelled by user or system",
"backupId", backup.ID,
"isCancelled", isCancelled,
"isShutdown", isShutdown,
)
backup.Status = backups_core.BackupStatusCanceled
backup.BackupDurationMs = time.Since(start).Milliseconds()
backup.BackupSizeMb = 0
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to save cancelled backup", "error", err)
}
// Delete partial backup from storage
storage, storageErr := n.storageService.GetStorageByID(backup.StorageID)
if storageErr == nil {
if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.ID); deleteErr != nil {
n.logger.Error(
"Failed to delete partial backup file",
"backupId",
backup.ID,
"error",
deleteErr,
)
}
}
return
}
backup.FailMessage = &errMsg
backup.Status = backups_core.BackupStatusFailed
backup.BackupDurationMs = time.Since(start).Milliseconds()
backup.BackupSizeMb = 0
if updateErr := n.databaseService.SetBackupError(databaseID, errMsg); updateErr != nil {
n.logger.Error(
"Failed to update database last backup time",
"databaseId",
databaseID,
"error",
updateErr,
)
}
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to save backup", "error", err)
}
n.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupFailed,
&errMsg,
)
return
}
backup.Status = backups_core.BackupStatusCompleted
backup.BackupDurationMs = time.Since(start).Milliseconds()
// Update backup with encryption metadata if provided
if backupMetadata != nil {
backup.EncryptionSalt = backupMetadata.EncryptionSalt
backup.EncryptionIV = backupMetadata.EncryptionIV
backup.Encryption = backupMetadata.Encryption
}
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to save backup", "error", err)
return
}
// Update database last backup time
now := time.Now().UTC()
if updateErr := n.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
n.logger.Error(
"Failed to update database last backup time",
"databaseId",
databaseID,
"error",
updateErr,
)
}
if backup.Status != backups_core.BackupStatusCompleted && !isCallNotifier {
return
}
n.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupSuccess,
nil,
)
}
func (n *BackuperNode) SendBackupNotification(
backupConfig *backups_config.BackupConfig,
backup *backups_core.Backup,
notificationType backups_config.BackupNotificationType,
errorMessage *string,
) {
database, err := n.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
if err != nil {
return
}
workspace, err := n.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
if err != nil {
return
}
for _, notifier := range database.Notifiers {
if !slices.Contains(
backupConfig.SendNotificationsOn,
notificationType,
) {
continue
}
title := ""
switch notificationType {
case backups_config.NotificationBackupFailed:
title = fmt.Sprintf(
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
database.Name,
workspace.Name,
)
case backups_config.NotificationBackupSuccess:
title = fmt.Sprintf(
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
database.Name,
workspace.Name,
)
}
message := ""
if errorMessage != nil {
message = *errorMessage
} else {
// Format size conditionally
var sizeStr string
if backup.BackupSizeMb < 1024 {
sizeStr = fmt.Sprintf("%.2f MB", backup.BackupSizeMb)
} else {
sizeGB := backup.BackupSizeMb / 1024
sizeStr = fmt.Sprintf("%.2f GB", sizeGB)
}
// Format duration as "0m 0s 0ms"
totalMs := backup.BackupDurationMs
minutes := totalMs / (1000 * 60)
seconds := (totalMs % (1000 * 60)) / 1000
durationStr := fmt.Sprintf("%dm %ds", minutes, seconds)
message = fmt.Sprintf(
"Backup completed successfully in %s.\nCompressed backup size: %s",
durationStr,
sizeStr,
)
}
n.notificationSender.SendNotification(
&notifier,
title,
message,
)
}
}
func (n *BackuperNode) sendHeartbeat(backupNode *task_registry.TaskNode) {
n.lastHeartbeat = time.Now().UTC()
if err := n.tasksRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
n.logger.Error("Failed to send heartbeat", "error", err)
}
}

View File

@@ -1,4 +1,4 @@
package backups
package backuping
import (
"context"
@@ -8,17 +8,15 @@ import (
"time"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_services "databasus-backend/internal/features/workspaces/services"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
cache_utils "databasus-backend/internal/util/cache"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -26,6 +24,7 @@ import (
)
func Test_BackupExecuted_NotificationSent(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
@@ -50,23 +49,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
t.Run("BackupFailed_FailNotificationSent", func(t *testing.T) {
mockNotificationSender := &MockNotificationSender{}
backupService := &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
&CreateFailedBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
nil,
backuperNode := CreateTestBackuperNode()
backuperNode.notificationSender = mockNotificationSender
backuperNode.createBackupUseCase = &CreateFailedBackupUsecase{}
// Create a backup record directly that will be looked up by MakeBackup
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err := backupRepository.Save(backup)
assert.NoError(t, err)
// Set up expectations
mockNotificationSender.On("SendNotification",
@@ -79,7 +74,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
}),
).Once()
backupService.MakeBackup(database.ID, true)
backuperNode.MakeBackup(backup.ID, true)
// Verify all expectations were met
mockNotificationSender.AssertExpectations(t)
@@ -87,6 +82,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
t.Run("BackupSuccess_SuccessNotificationSent", func(t *testing.T) {
mockNotificationSender := &MockNotificationSender{}
backuperNode := CreateTestBackuperNode()
backuperNode.notificationSender = mockNotificationSender
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
// Create a backup record directly that will be looked up by MakeBackup
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err := backupRepository.Save(backup)
assert.NoError(t, err)
// Set up expectations
mockNotificationSender.On("SendNotification",
@@ -99,25 +107,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
}),
).Once()
backupService := &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
&CreateSuccessBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
nil,
}
backupService.MakeBackup(database.ID, true)
backuperNode.MakeBackup(backup.ID, true)
// Verify all expectations were met
mockNotificationSender.AssertExpectations(t)
@@ -125,23 +115,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
t.Run("BackupSuccess_VerifyNotificationContent", func(t *testing.T) {
mockNotificationSender := &MockNotificationSender{}
backupService := &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
&CreateSuccessBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
nil,
backuperNode := CreateTestBackuperNode()
backuperNode.notificationSender = mockNotificationSender
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
// Create a backup record directly that will be looked up by MakeBackup
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err := backupRepository.Save(backup)
assert.NoError(t, err)
// capture arguments
var capturedNotifier *notifiers.Notifier
@@ -158,7 +144,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
capturedMessage = args.Get(2).(string)
}).Once()
backupService.MakeBackup(database.ID, true)
backuperNode.MakeBackup(backup.ID, true)
// Verify expectations were met
mockNotificationSender.AssertExpectations(t)

View File

@@ -0,0 +1,71 @@
package backuping
import (
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
task_registry "databasus-backend/internal/features/tasks/registry"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
"time"
"github.com/google/uuid"
)
var backupRepository = &backups_core.BackupRepository{}
var taskCancelManager = tasks_cancellation.GetTaskCancelManager()
var nodesRegistry = task_registry.GetTaskNodesRegistry()
func getNodeID() uuid.UUID {
nodeIDStr := config.GetEnv().NodeID
nodeID, err := uuid.Parse(nodeIDStr)
if err != nil {
logger.GetLogger().Error("Failed to parse node ID from config", "error", err)
panic(err)
}
return nodeID
}
var backuperNode = &BackuperNode{
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
taskCancelManager,
nodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
getNodeID(),
time.Time{},
}
var backupsScheduler = &BackupsScheduler{
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
taskCancelManager,
nodesRegistry,
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]BackupToNodeRelation),
backuperNode,
}
func GetBackupsScheduler() *BackupsScheduler {
return backupsScheduler
}
func GetBackuperNode() *BackuperNode {
return backuperNode
}

View File

@@ -0,0 +1,8 @@
package backuping
import "github.com/google/uuid"
type BackupToNodeRelation struct {
NodeID uuid.UUID `json:"nodeId"`
BackupsIDs []uuid.UUID `json:"backupsIds"`
}

View File

@@ -1,4 +1,4 @@
package backups
package backuping
import (
"databasus-backend/internal/features/notifiers"

View File

@@ -0,0 +1,603 @@
package backuping
import (
"context"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
task_registry "databasus-backend/internal/features/tasks/registry"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
"fmt"
"log/slog"
"time"
"github.com/google/uuid"
)
const (
schedulerStartupDelay = 1 * time.Minute
schedulerTickerInterval = 1 * time.Minute
schedulerHealthcheckThreshold = 5 * time.Minute
)
type BackupsScheduler struct {
backupRepository *backups_core.BackupRepository
backupConfigService *backups_config.BackupConfigService
storageService *storages.StorageService
taskCancelManager *task_cancellation.TaskCancelManager
tasksRegistry *task_registry.TaskNodesRegistry
lastBackupTime time.Time
logger *slog.Logger
backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
backuperNode *BackuperNode
}
func (s *BackupsScheduler) Run(ctx context.Context) {
s.lastBackupTime = time.Now().UTC()
if config.GetEnv().IsManyNodesMode {
// wait other nodes to start
time.Sleep(schedulerStartupDelay)
}
if err := s.failBackupsInProgress(); err != nil {
s.logger.Error("Failed to fail backups in progress", "error", err)
panic(err)
}
if err := s.tasksRegistry.SubscribeForTasksCompletions(s.onBackupCompleted); err != nil {
s.logger.Error("Failed to subscribe to backup completions", "error", err)
panic(err)
}
defer func() {
if err := s.tasksRegistry.UnsubscribeForTasksCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
}
}()
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(schedulerTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.cleanOldBackups(); err != nil {
s.logger.Error("Failed to clean old backups", "error", err)
}
if err := s.checkDeadNodesAndFailBackups(); err != nil {
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
}
if err := s.runPendingBackups(); err != nil {
s.logger.Error("Failed to run pending backups", "error", err)
}
s.lastBackupTime = time.Now().UTC()
}
}
}
func (s *BackupsScheduler) IsSchedulerRunning() bool {
// if last backup time is more than 5 minutes ago, return false
return s.lastBackupTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
}
func (s *BackupsScheduler) failBackupsInProgress() error {
backupsInProgress, err := s.backupRepository.FindByStatus(backups_core.BackupStatusInProgress)
if err != nil {
return err
}
for _, backup := range backupsInProgress {
if err := s.taskCancelManager.CancelTask(backup.ID); err != nil {
s.logger.Error(
"Failed to cancel backup via task cancel manager",
"backupId",
backup.ID,
"error",
err,
)
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
continue
}
failMessage := "Backup failed due to application restart"
backup.FailMessage = &failMessage
backup.Status = backups_core.BackupStatusFailed
backup.BackupSizeMb = 0
s.backuperNode.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupFailed,
&failMessage,
)
if err := s.backupRepository.Save(backup); err != nil {
return err
}
}
return nil
}
func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool) {
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
return
}
if backupConfig.StorageID == nil {
s.logger.Error("Backup config storage ID is nil", "databaseId", databaseID)
return
}
leastBusyNodeID, err := s.calculateLeastBusyNode()
if err != nil {
s.logger.Error(
"Failed to calculate least busy node",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
return
}
backup := &backups_core.Backup{
DatabaseID: backupConfig.DatabaseID,
StorageID: *backupConfig.StorageID,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 0,
CreatedAt: time.Now().UTC(),
}
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error(
"Failed to save backup",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
return
}
if err := s.tasksRegistry.IncrementTasksInProgress(leastBusyNodeID.String()); err != nil {
s.logger.Error(
"Failed to increment backups in progress",
"nodeId",
leastBusyNodeID,
"backupId",
backup.ID,
"error",
err,
)
return
}
if err := s.tasksRegistry.AssignTaskToNode(leastBusyNodeID.String(), backup.ID, isCallNotifier); err != nil {
s.logger.Error(
"Failed to submit backup",
"nodeId",
leastBusyNodeID,
"backupId",
backup.ID,
"error",
err,
)
if decrementErr := s.tasksRegistry.DecrementTasksInProgress(leastBusyNodeID.String()); decrementErr != nil {
s.logger.Error(
"Failed to decrement backups in progress after submit failure",
"nodeId",
leastBusyNodeID,
"error",
decrementErr,
)
}
return
}
if relation, exists := s.backupToNodeRelations[*leastBusyNodeID]; exists {
relation.BackupsIDs = append(relation.BackupsIDs, backup.ID)
s.backupToNodeRelations[*leastBusyNodeID] = relation
} else {
s.backupToNodeRelations[*leastBusyNodeID] = BackupToNodeRelation{
NodeID: *leastBusyNodeID,
BackupsIDs: []uuid.UUID{backup.ID},
}
}
s.logger.Info(
"Successfully triggered scheduled backup",
"databaseId",
backupConfig.DatabaseID,
"backupId",
backup.ID,
"nodeId",
leastBusyNodeID,
)
}
// GetRemainedBackupTryCount returns the number of remaining backup tries for a given backup.
// If the backup is not failed or the backup config does not allow retries, it returns 0.
// If the backup is failed and the backup config allows retries, it returns the number of remaining tries.
// If the backup is failed and the backup config does not allow retries, it returns 0.
func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Backup) int {
if lastBackup == nil {
return 0
}
if lastBackup.Status != backups_core.BackupStatusFailed {
return 0
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
return 0
}
if !backupConfig.IsRetryIfFailed {
return 0
}
maxFailedTriesCount := backupConfig.MaxFailedTriesCount
lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
lastBackup.DatabaseID,
maxFailedTriesCount,
)
if err != nil {
s.logger.Error("Failed to find last backups by database ID", "error", err)
return 0
}
lastFailedBackups := make([]*backups_core.Backup, 0)
for _, backup := range lastBackups {
if backup.Status == backups_core.BackupStatusFailed {
lastFailedBackups = append(lastFailedBackups, backup)
}
}
return maxFailedTriesCount - len(lastFailedBackups)
}
func (s *BackupsScheduler) cleanOldBackups() error {
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
backupStorePeriod := backupConfig.StorePeriod
if backupStorePeriod == period.PeriodForever {
continue
}
storeDuration := backupStorePeriod.ToDuration()
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
backupConfig.DatabaseID,
dateBeforeBackupsShouldBeDeleted,
)
if err != nil {
s.logger.Error(
"Failed to find old backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
for _, backup := range oldBackups {
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
s.logger.Error(
"Failed to get storage by ID",
"storageId",
backup.StorageID,
"error",
err,
)
continue
}
encryptor := encryption.GetFieldEncryptor()
err = storage.DeleteFile(encryptor, backup.ID)
if err != nil {
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
}
if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
continue
}
s.logger.Info(
"Deleted old backup",
"backupId",
backup.ID,
"databaseId",
backupConfig.DatabaseID,
)
}
}
return nil
}
func (s *BackupsScheduler) runPendingBackups() error {
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
if backupConfig.BackupInterval == nil {
continue
}
lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
if err != nil {
s.logger.Error(
"Failed to get last backup for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
var lastBackupTime *time.Time
if lastBackup != nil {
lastBackupTime = &lastBackup.CreatedAt
}
remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
remainedBackupTryCount > 0 {
s.logger.Info(
"Triggering scheduled backup",
"databaseId",
backupConfig.DatabaseID,
"intervalType",
backupConfig.BackupInterval.Interval,
)
s.StartBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
continue
}
}
return nil
}
func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
nodes, err := s.tasksRegistry.GetAvailableNodes()
if err != nil {
return nil, fmt.Errorf("failed to get available nodes: %w", err)
}
if len(nodes) == 0 {
return nil, fmt.Errorf("no nodes available")
}
stats, err := s.tasksRegistry.GetNodesStats()
if err != nil {
return nil, fmt.Errorf("failed to get backup nodes stats: %w", err)
}
statsMap := make(map[uuid.UUID]int)
for _, stat := range stats {
statsMap[stat.ID] = stat.ActiveTasks
}
var bestNode *task_registry.TaskNode
var bestScore float64 = -1
for i := range nodes {
node := &nodes[i]
activeBackups := statsMap[node.ID]
var score float64
if node.ThroughputMBs > 0 {
score = float64(activeBackups) / float64(node.ThroughputMBs)
} else {
score = float64(activeBackups) * 1000
}
if bestNode == nil || score < bestScore {
bestNode = node
bestScore = score
}
}
if bestNode == nil {
return nil, fmt.Errorf("no suitable nodes available")
}
return &bestNode.ID, nil
}
func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUID) {
nodeID, err := uuid.Parse(nodeIDStr)
if err != nil {
s.logger.Error(
"Failed to parse node ID from completion message",
"nodeId",
nodeIDStr,
"error",
err,
)
return
}
// Verify this task is actually a backup (registry contains multiple task types)
_, err = s.backupRepository.FindByID(backupID)
if err != nil {
// Not a backup task, ignore it
return
}
relation, exists := s.backupToNodeRelations[nodeID]
if !exists {
s.logger.Warn(
"Received completion for unknown node",
"nodeId",
nodeID,
"backupId",
backupID,
)
return
}
newBackupIDs := make([]uuid.UUID, 0)
found := false
for _, id := range relation.BackupsIDs {
if id == backupID {
found = true
continue
}
newBackupIDs = append(newBackupIDs, id)
}
if !found {
s.logger.Warn(
"Backup not found in node's backup list",
"nodeId",
nodeID,
"backupId",
backupID,
)
return
}
if len(newBackupIDs) == 0 {
delete(s.backupToNodeRelations, nodeID)
} else {
relation.BackupsIDs = newBackupIDs
s.backupToNodeRelations[nodeID] = relation
}
if err := s.tasksRegistry.DecrementTasksInProgress(nodeIDStr); err != nil {
s.logger.Error(
"Failed to decrement backups in progress",
"nodeId",
nodeID,
"backupId",
backupID,
"error",
err,
)
}
}
func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
nodes, err := s.tasksRegistry.GetAvailableNodes()
if err != nil {
return fmt.Errorf("failed to get available nodes: %w", err)
}
aliveNodeIDs := make(map[uuid.UUID]bool)
for _, node := range nodes {
aliveNodeIDs[node.ID] = true
}
for nodeID, relation := range s.backupToNodeRelations {
if aliveNodeIDs[nodeID] {
continue
}
s.logger.Warn(
"Node is dead, failing its backups",
"nodeId",
nodeID,
"backupCount",
len(relation.BackupsIDs),
)
for _, backupID := range relation.BackupsIDs {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
s.logger.Error(
"Failed to find backup for dead node",
"nodeId",
nodeID,
"backupId",
backupID,
"error",
err,
)
continue
}
failMessage := "Backup failed due to node unavailability"
backup.FailMessage = &failMessage
backup.Status = backups_core.BackupStatusFailed
backup.BackupSizeMb = 0
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error(
"Failed to save failed backup for dead node",
"nodeId",
nodeID,
"backupId",
backupID,
"error",
err,
)
continue
}
if err := s.tasksRegistry.DecrementTasksInProgress(nodeID.String()); err != nil {
s.logger.Error(
"Failed to decrement backups in progress for dead node",
"nodeId",
nodeID,
"backupId",
backupID,
"error",
err,
)
}
s.logger.Info(
"Failed backup due to dead node",
"nodeId",
nodeID,
"backupId",
backupID,
)
}
delete(s.backupToNodeRelations, nodeID)
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,280 @@
package backuping
import (
"context"
"fmt"
"testing"
"time"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
task_registry "databasus-backend/internal/features/tasks/registry"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_services "databasus-backend/internal/features/workspaces/services"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
func CreateTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
)
return router
}
func CreateTestBackuperNode() *BackuperNode {
return &BackuperNode{
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
taskCancelManager,
nodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
uuid.New(),
time.Time{},
}
}
// WaitForBackupCompletion waits for a new backup to be created and completed (or failed)
// for the given database. It checks for backups with count greater than expectedInitialCount.
func WaitForBackupCompletion(
t *testing.T,
databaseID uuid.UUID,
expectedInitialCount int,
timeout time.Duration,
) {
deadline := time.Now().UTC().Add(timeout)
for time.Now().UTC().Before(deadline) {
backups, err := backupRepository.FindByDatabaseID(databaseID)
if err != nil {
t.Logf("WaitForBackupCompletion: error finding backups: %v", err)
time.Sleep(50 * time.Millisecond)
continue
}
t.Logf(
"WaitForBackupCompletion: found %d backups (expected > %d)",
len(backups),
expectedInitialCount,
)
if len(backups) > expectedInitialCount {
// Check if the newest backup has completed or failed
newestBackup := backups[0]
t.Logf("WaitForBackupCompletion: newest backup status: %s", newestBackup.Status)
if newestBackup.Status == backups_core.BackupStatusCompleted ||
newestBackup.Status == backups_core.BackupStatusFailed ||
newestBackup.Status == backups_core.BackupStatusCanceled {
t.Logf(
"WaitForBackupCompletion: backup finished with status %s",
newestBackup.Status,
)
return
}
}
time.Sleep(50 * time.Millisecond)
}
t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
}
// StartBackuperNodeForTest starts a BackuperNode in a goroutine for testing.
// The node registers itself in the registry and subscribes to backup assignments.
// Returns a context cancel function that should be deferred to stop the node.
func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.CancelFunc {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
backuperNode.Run(ctx)
close(done)
}()
// Poll registry for node presence instead of fixed sleep
deadline := time.Now().UTC().Add(5 * time.Second)
for time.Now().UTC().Before(deadline) {
nodes, err := nodesRegistry.GetAvailableNodes()
if err == nil {
for _, node := range nodes {
if node.ID == backuperNode.nodeID {
t.Logf("BackuperNode registered in registry: %s", backuperNode.nodeID)
return func() {
cancel()
select {
case <-done:
t.Log("BackuperNode stopped gracefully")
case <-time.After(2 * time.Second):
t.Log("BackuperNode stop timeout")
}
}
}
}
}
time.Sleep(50 * time.Millisecond)
}
t.Fatalf("BackuperNode failed to register in registry within timeout")
return nil
}
// StartSchedulerForTest starts the BackupsScheduler in a goroutine for testing.
// The scheduler subscribes to task completions and manages backup lifecycle.
// Returns a context cancel function that should be deferred to stop the scheduler.
func StartSchedulerForTest(t *testing.T) context.CancelFunc {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
GetBackupsScheduler().Run(ctx)
close(done)
}()
// Give scheduler time to subscribe to completions
time.Sleep(100 * time.Millisecond)
t.Log("BackupsScheduler started")
return func() {
cancel()
select {
case <-done:
t.Log("BackupsScheduler stopped gracefully")
case <-time.After(2 * time.Second):
t.Log("BackupsScheduler stop timeout")
}
}
}
// StopBackuperNodeForTest stops the BackuperNode by canceling its context.
// It waits for the node to unregister from the registry.
func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNode *BackuperNode) {
cancel()
// Wait for node to unregister from registry
deadline := time.Now().UTC().Add(2 * time.Second)
for time.Now().UTC().Before(deadline) {
nodes, err := nodesRegistry.GetAvailableNodes()
if err == nil {
found := false
for _, node := range nodes {
if node.ID == backuperNode.nodeID {
found = true
break
}
}
if !found {
t.Logf("BackuperNode unregistered from registry: %s", backuperNode.nodeID)
return
}
}
time.Sleep(50 * time.Millisecond)
}
t.Logf("BackuperNode stop completed for %s", backuperNode.nodeID)
}
func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
backupNode := task_registry.TaskNode{
ID: nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: lastHeartbeat,
}
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
}
func UpdateNodeHeartbeatDirectly(
nodeID uuid.UUID,
throughputMBs int,
lastHeartbeat time.Time,
) error {
backupNode := task_registry.TaskNode{
ID: nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: lastHeartbeat,
}
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
}
func GetNodeFromRegistry(nodeID uuid.UUID) (*task_registry.TaskNode, error) {
nodes, err := nodesRegistry.GetAvailableNodes()
if err != nil {
return nil, err
}
for _, node := range nodes {
if node.ID == nodeID {
return &node, nil
}
}
return nil, fmt.Errorf("node not found")
}
// WaitForActiveTasksDecrease waits for the active task count to decrease below the initial count.
// It polls the registry every 500ms until the count decreases or the timeout is reached.
// Returns true if the count decreased, false if timeout was reached.
func WaitForActiveTasksDecrease(
t *testing.T,
nodeID uuid.UUID,
initialCount int,
timeout time.Duration,
) bool {
deadline := time.Now().UTC().Add(timeout)
for time.Now().UTC().Before(deadline) {
stats, err := nodesRegistry.GetNodesStats()
if err != nil {
t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err)
time.Sleep(500 * time.Millisecond)
continue
}
for _, stat := range stats {
if stat.ID == nodeID {
t.Logf(
"WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)",
stat.ActiveTasks,
initialCount,
)
if stat.ActiveTasks < initialCount {
t.Logf(
"WaitForActiveTasksDecrease: active tasks decreased from %d to %d",
initialCount,
stat.ActiveTasks,
)
return true
}
break
}
}
time.Sleep(500 * time.Millisecond)
}
t.Logf("WaitForActiveTasksDecrease: timeout waiting for active tasks to decrease")
return false
}

View File

@@ -1,11 +1,15 @@
package backups
import (
"context"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/databases"
users_middleware "databasus-backend/internal/features/users/middleware"
"fmt"
"io"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -170,9 +174,10 @@ func (c *BackupController) CancelBackup(ctx *gin.Context) {
// @Description Generate a token for downloading a backup file (valid for 5 minutes)
// @Tags backups
// @Param id path string true "Backup ID"
// @Success 200 {object} GenerateDownloadTokenResponse
// @Success 200 {object} backups_download.GenerateDownloadTokenResponse
// @Failure 400
// @Failure 401
// @Failure 409 {object} map[string]string "Download already in progress"
// @Router /backups/{id}/download-token [post]
func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
@@ -189,6 +194,15 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
response, err := c.backupService.GenerateDownloadToken(user, id)
if err != nil {
if err == backups_download.ErrDownloadAlreadyInProgress {
ctx.JSON(
http.StatusConflict,
gin.H{
"error": "Download already in progress for some of backups. Please wait until previous download completed or cancel it",
},
)
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -198,14 +212,22 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
// GetFile
// @Summary Download a backup file
// @Description Download the backup file for the specified backup using a download token
// @Description Download the backup file for the specified backup using a download token.
// @Description
// @Description **Download Concurrency Control:**
// @Description - Only one download per user is allowed at a time
// @Description - If a download is already in progress, returns 409 Conflict
// @Description - Downloads are tracked using cache with 5-second TTL and 3-second heartbeat
// @Description - Browser cancellations automatically release the download lock
// @Description - Server crashes are handled via automatic cache expiry (5 seconds)
// @Tags backups
// @Param id path string true "Backup ID"
// @Param token query string true "Download token"
// @Success 200 {file} file
// @Failure 400
// @Failure 401
// @Failure 500
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 409 {object} map[string]string "Download already in progress"
// @Failure 500 {object} map[string]string
// @Router /backups/{id}/file [get]
func (c *BackupController) GetFile(ctx *gin.Context) {
token := ctx.Query("token")
@@ -214,7 +236,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
return
}
// Get backup ID from URL
backupIDParam := ctx.Param("id")
backupID, err := uuid.Parse(backupIDParam)
if err != nil {
@@ -222,13 +243,22 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
return
}
downloadToken, err := c.backupService.ValidateDownloadToken(token)
downloadToken, rateLimiter, err := c.backupService.ValidateDownloadToken(token)
if err != nil {
if err == backups_download.ErrDownloadAlreadyInProgress {
ctx.JSON(
http.StatusConflict,
gin.H{
"error": "download already in progress for this user. Please wait until previous download completed or cancel it",
},
)
return
}
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
return
}
// Verify token is for the requested backup
if downloadToken.BackupID != backupID {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
return
@@ -238,18 +268,28 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
downloadToken.BackupID,
)
if err != nil {
c.backupService.UnregisterDownload(downloadToken.UserID)
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
rateLimitedReader := backups_download.NewRateLimitedReader(fileReader, rateLimiter)
heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
defer func() {
if err := fileReader.Close(); err != nil {
cancelHeartbeat()
c.backupService.UnregisterDownload(downloadToken.UserID)
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
if err := rateLimitedReader.Close(); err != nil {
fmt.Printf("Error closing file reader: %v\n", err)
}
}()
go c.startDownloadHeartbeat(heartbeatCtx, downloadToken.UserID)
filename := c.generateBackupFilename(backup, database)
// Set Content-Length for progress tracking
if backup.BackupSizeMb > 0 {
sizeBytes := int64(backup.BackupSizeMb * 1024 * 1024)
ctx.Header("Content-Length", fmt.Sprintf("%d", sizeBytes))
@@ -261,13 +301,12 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
fmt.Sprintf("attachment; filename=\"%s\"", filename),
)
_, err = io.Copy(ctx.Writer, fileReader)
_, err = io.Copy(ctx.Writer, rateLimitedReader)
if err != nil {
fmt.Printf("Error streaming file: %v\n", err)
return
}
// Write audit log after successful download
c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
}
@@ -276,7 +315,7 @@ type MakeBackupRequest struct {
}
func (c *BackupController) generateBackupFilename(
backup *Backup,
backup *backups_core.Backup,
database *databases.Database,
) string {
// Format timestamp as YYYY-MM-DD_HH-mm-ss
@@ -333,3 +372,17 @@ func sanitizeFilename(name string) string {
return string(result)
}
func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uuid.UUID) {
ticker := time.NewTicker(backups_download.GetDownloadHeartbeatInterval())
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
c.backupService.RefreshDownloadLock(userID)
}
}
}

View File

@@ -18,7 +18,8 @@ import (
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/download_token"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
@@ -478,7 +479,7 @@ func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
)
if tt.expectSuccess {
var response GenerateDownloadTokenResponse
var response backups_download.GenerateDownloadTokenResponse
err := json.Unmarshal(testResp.Body, &response)
assert.NoError(t, err)
assert.NotEmpty(t, response.Token)
@@ -499,7 +500,7 @@ func Test_DownloadBackup_WithValidToken_Success(t *testing.T) {
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse GenerateDownloadTokenResponse
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
@@ -620,7 +621,7 @@ func Test_DownloadBackup_TokenUsedOnce_CannotReuseToken(t *testing.T) {
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse GenerateDownloadTokenResponse
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
@@ -683,7 +684,7 @@ func Test_DownloadBackup_WithDifferentBackupToken_Unauthorized(t *testing.T) {
backup2 := createTestBackup(database2, owner)
// Generate token for backup1
var tokenResponse GenerateDownloadTokenResponse
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
@@ -714,7 +715,7 @@ func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse GenerateDownloadTokenResponse
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
@@ -806,7 +807,7 @@ func Test_DownloadBackup_ProperFilenameForPostgreSQL(t *testing.T) {
backup := createTestBackup(database, owner)
// Generate download token
var tokenResponse GenerateDownloadTokenResponse
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
@@ -897,22 +898,22 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
_, err = configService.SaveBackupConfig(config)
assert.NoError(t, err)
backup := &Backup{
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusInProgress,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 0,
BackupDurationMs: 0,
CreatedAt: time.Now().UTC(),
}
repo := &BackupRepository{}
repo := &backups_core.BackupRepository{}
err = repo.Save(backup)
assert.NoError(t, err)
// Register a cancellable context for the backup
GetBackupService().backupContextManager.RegisterBackup(backup.ID, func() {})
GetBackupService().taskCancelManager.RegisterTask(backup.ID, func() {})
resp := test_utils.MakePostRequest(
t,
@@ -949,6 +950,189 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
assert.True(t, foundCancelLog, "Cancel audit log should be created")
}
func Test_ConcurrentDownloadPrevention(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
var token1Response backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&token1Response,
)
var token2Response backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&token2Response,
)
downloadInProgress := make(chan bool, 1)
downloadComplete := make(chan bool, 1)
go func() {
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf(
"/api/v1/backups/%s/file?token=%s",
backup.ID.String(),
token1Response.Token,
),
"",
http.StatusOK,
)
downloadComplete <- true
}()
time.Sleep(50 * time.Millisecond)
service := GetBackupService()
if !service.IsDownloadInProgress(owner.UserID) {
t.Log("Warning: First download completed before we could test concurrency")
<-downloadComplete
return
}
downloadInProgress <- true
resp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token2Response.Token),
"",
http.StatusConflict,
)
var errorResponse map[string]string
err := json.Unmarshal(resp.Body, &errorResponse)
assert.NoError(t, err)
assert.Contains(t, errorResponse["error"], "download already in progress")
<-downloadComplete
<-downloadInProgress
time.Sleep(100 * time.Millisecond)
var token3Response backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&token3Response,
)
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token3Response.Token),
"",
http.StatusOK,
)
t.Log("Database:", database.Name)
t.Log(
"Successfully prevented concurrent downloads and allowed subsequent downloads after completion",
)
}
func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
var token1Response backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&token1Response,
)
downloadComplete := make(chan bool, 1)
go func() {
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf(
"/api/v1/backups/%s/file?token=%s",
backup.ID.String(),
token1Response.Token,
),
"",
http.StatusOK,
)
downloadComplete <- true
}()
time.Sleep(50 * time.Millisecond)
service := GetBackupService()
if !service.IsDownloadInProgress(owner.UserID) {
t.Log("Warning: First download completed before we could test token generation blocking")
<-downloadComplete
return
}
resp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusConflict,
)
var errorResponse map[string]string
err := json.Unmarshal(resp.Body, &errorResponse)
assert.NoError(t, err)
assert.Contains(t, errorResponse["error"], "download already in progress")
<-downloadComplete
time.Sleep(100 * time.Millisecond)
var token2Response backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&token2Response,
)
assert.NotEmpty(t, token2Response.Token)
assert.NotEqual(t, token1Response.Token, token2Response.Token)
t.Log("Database:", database.Name)
t.Log(
"Successfully blocked token generation during download and allowed generation after completion",
)
}
func createTestRouter() *gin.Engine {
return CreateTestRouter()
}
@@ -1038,7 +1222,7 @@ func createTestDatabaseWithBackups(
workspace *workspaces_models.Workspace,
owner *users_dto.SignInResponseDTO,
router *gin.Engine,
) (*databases.Database, *Backup) {
) (*databases.Database, *backups_core.Backup) {
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
@@ -1064,7 +1248,7 @@ func createTestDatabaseWithBackups(
func createTestBackup(
database *databases.Database,
owner *users_dto.SignInResponseDTO,
) *Backup {
) *backups_core.Backup {
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
if err != nil {
@@ -1076,17 +1260,17 @@ func createTestBackup(
panic("No storage found for workspace")
}
backup := &Backup{
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storages[0].ID,
Status: BackupStatusCompleted,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
CreatedAt: time.Now().UTC(),
}
repo := &BackupRepository{}
repo := &backups_core.BackupRepository{}
if err := repo.Save(backup); err != nil {
panic(err)
}
@@ -1116,7 +1300,7 @@ func createExpiredDownloadToken(backupID, userID uuid.UUID) string {
}
// Manually update the token to be expired
repo := &download_token.DownloadTokenRepository{}
repo := &backups_download.DownloadTokenRepository{}
downloadToken, err := repo.FindByToken(token)
if err != nil || downloadToken == nil {
panic(fmt.Sprintf("Failed to find generated token: %v", err))
@@ -1130,3 +1314,267 @@ func createExpiredDownloadToken(backupID, userID uuid.UUID) string {
return token
}
func Test_BandwidthThrottling_SingleDownload_Uses75Percent(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
bandwidthManager := backups_download.GetBandwidthManager()
initialCount := bandwidthManager.GetActiveDownloadCount()
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&tokenResponse,
)
downloadStarted := make(chan bool, 1)
downloadComplete := make(chan bool, 1)
go func() {
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf(
"/api/v1/backups/%s/file?token=%s",
backup.ID.String(),
tokenResponse.Token,
),
"",
http.StatusOK,
)
downloadComplete <- true
}()
time.Sleep(50 * time.Millisecond)
activeCount := bandwidthManager.GetActiveDownloadCount()
if activeCount > initialCount {
downloadStarted <- true
assert.Equal(t, initialCount+1, activeCount, "Should have one active download")
}
<-downloadComplete
if len(downloadStarted) > 0 {
<-downloadStarted
}
time.Sleep(50 * time.Millisecond)
finalCount := bandwidthManager.GetActiveDownloadCount()
assert.Equal(t, initialCount, finalCount, "Download should be unregistered after completion")
}
func Test_BandwidthThrottling_MultipleDownloads_ShareBandwidth(t *testing.T) {
router := createTestRouter()
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
owner3 := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner1, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
owner2,
users_enums.WorkspaceRoleMember,
owner1.Token,
router,
)
workspaces_testing.AddMemberToWorkspace(
workspace,
owner3,
users_enums.WorkspaceRoleMember,
owner1.Token,
router,
)
database := createTestDatabase("Test Database", workspace.ID, owner1.Token, router)
storage := createTestStorage(workspace.ID)
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
config.IsBackupsEnabled = true
config.StorageID = &storage.ID
config.Storage = storage
_, err = configService.SaveBackupConfig(config)
assert.NoError(t, err)
backup1 := createTestBackup(database, owner1)
backup2 := createTestBackup(database, owner2)
backup3 := createTestBackup(database, owner3)
var token1, token2, token3 backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup1.ID.String()),
"Bearer "+owner1.Token,
nil,
http.StatusOK,
&token1,
)
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup2.ID.String()),
"Bearer "+owner2.Token,
nil,
http.StatusOK,
&token2,
)
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup3.ID.String()),
"Bearer "+owner3.Token,
nil,
http.StatusOK,
&token3,
)
bandwidthManager := backups_download.GetBandwidthManager()
initialCount := bandwidthManager.GetActiveDownloadCount()
complete1 := make(chan bool, 1)
complete2 := make(chan bool, 1)
complete3 := make(chan bool, 1)
go func() {
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup1.ID.String(), token1.Token),
"",
http.StatusOK,
)
complete1 <- true
}()
go func() {
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup2.ID.String(), token2.Token),
"",
http.StatusOK,
)
complete2 <- true
}()
go func() {
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup3.ID.String(), token3.Token),
"",
http.StatusOK,
)
complete3 <- true
}()
time.Sleep(100 * time.Millisecond)
<-complete1
<-complete2
<-complete3
time.Sleep(100 * time.Millisecond)
finalCount := bandwidthManager.GetActiveDownloadCount()
assert.Equal(t, initialCount, finalCount, "All downloads should be unregistered")
}
func Test_BandwidthThrottling_DynamicAdjustment(t *testing.T) {
router := createTestRouter()
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner1, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
owner2,
users_enums.WorkspaceRoleMember,
owner1.Token,
router,
)
database := createTestDatabase("Test Database", workspace.ID, owner1.Token, router)
storage := createTestStorage(workspace.ID)
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
config.IsBackupsEnabled = true
config.StorageID = &storage.ID
config.Storage = storage
_, err = configService.SaveBackupConfig(config)
assert.NoError(t, err)
backup1 := createTestBackup(database, owner1)
backup2 := createTestBackup(database, owner2)
var token1, token2 backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup1.ID.String()),
"Bearer "+owner1.Token,
nil,
http.StatusOK,
&token1,
)
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup2.ID.String()),
"Bearer "+owner2.Token,
nil,
http.StatusOK,
&token2,
)
bandwidthManager := backups_download.GetBandwidthManager()
initialCount := bandwidthManager.GetActiveDownloadCount()
complete1 := make(chan bool, 1)
complete2 := make(chan bool, 1)
go func() {
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup1.ID.String(), token1.Token),
"",
http.StatusOK,
)
complete1 <- true
}()
time.Sleep(50 * time.Millisecond)
go func() {
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup2.ID.String(), token2.Token),
"",
http.StatusOK,
)
complete2 <- true
}()
<-complete1
<-complete2
time.Sleep(100 * time.Millisecond)
finalCount := bandwidthManager.GetActiveDownloadCount()
assert.Equal(t, initialCount, finalCount, "All downloads completed and unregistered")
}

View File

@@ -1,4 +1,4 @@
package backups
package backups_core
type BackupStatus string

View File

@@ -1,4 +1,4 @@
package backups
package backups_core
import (
"context"

View File

@@ -1,4 +1,4 @@
package backups
package backups_core
import (
backups_config "databasus-backend/internal/features/backups/config"

View File

@@ -1,4 +1,4 @@
package backups
package backups_core
import (
"databasus-backend/internal/storage"

View File

@@ -1,63 +1,47 @@
package backups
import (
"time"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/download_token"
"databasus-backend/internal/features/backups/backups/backuping"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
)
var backupRepository = &BackupRepository{}
var backupRepository = &backups_core.BackupRepository{}
var backupContextManager = NewBackupContextManager()
var taskCancelManager = task_cancellation.GetTaskCancelManager()
var backupService = &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
notifiers.GetNotifierService(),
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
usecases.GetCreateBackupUsecase(),
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
backupContextManager,
download_token.GetDownloadTokenService(),
}
var backupBackgroundService = &BackupBackgroundService{
backupService,
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
time.Now().UTC(),
logger.GetLogger(),
databaseService: databases.GetDatabaseService(),
storageService: storages.GetStorageService(),
backupRepository: backupRepository,
notifierService: notifiers.GetNotifierService(),
notificationSender: notifiers.GetNotifierService(),
backupConfigService: backups_config.GetBackupConfigService(),
secretKeyService: encryption_secrets.GetSecretKeyService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
logger: logger.GetLogger(),
backupRemoveListeners: []backups_core.BackupRemoveListener{},
workspaceService: workspaces_services.GetWorkspaceService(),
auditLogService: audit_logs.GetAuditLogService(),
taskCancelManager: taskCancelManager,
downloadTokenService: backups_download.GetDownloadTokenService(),
backupSchedulerService: backuping.GetBackupsScheduler(),
}
var backupController = &BackupController{
backupService,
}
func SetupDependencies() {
backups_config.
GetBackupConfigService().
SetDatabaseStorageChangeListener(backupService)
databases.GetDatabaseService().AddDbRemoveListener(backupService)
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
backupService: backupService,
}
func GetBackupService() *BackupService {
@@ -68,10 +52,11 @@ func GetBackupController() *BackupController {
return backupController
}
func GetBackupBackgroundService() *BackupBackgroundService {
return backupBackgroundService
}
func SetupDependencies() {
backups_config.
GetBackupConfigService().
SetDatabaseStorageChangeListener(backupService)
func GetDownloadTokenBackgroundService() *download_token.DownloadTokenBackgroundService {
return download_token.GetDownloadTokenBackgroundService()
databases.GetDatabaseService().AddDbRemoveListener(backupService)
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
}

View File

@@ -0,0 +1,34 @@
package backups_download
import (
"context"
"log/slog"
"time"
)
type DownloadTokenBackgroundService struct {
downloadTokenService *DownloadTokenService
logger *slog.Logger
}
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
s.logger.Info("Starting download token cleanup background service")
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
s.logger.Error("Failed to clean expired download tokens", "error", err)
}
}
}
}

View File

@@ -0,0 +1,81 @@
package backups_download
import (
"fmt"
"sync"
"github.com/google/uuid"
)
type BandwidthManager struct {
mu sync.RWMutex
activeDownloads map[uuid.UUID]*activeDownload
maxTotalBytesPerSecond int64
bytesPerSecondPerDownload int64
}
type activeDownload struct {
userID uuid.UUID
rateLimiter *RateLimiter
}
func NewBandwidthManager(throughputMBs int) *BandwidthManager {
// Use 75% of total throughput
maxBytes := int64(throughputMBs) * 1024 * 1024 * 75 / 100
return &BandwidthManager{
activeDownloads: make(map[uuid.UUID]*activeDownload),
maxTotalBytesPerSecond: maxBytes,
bytesPerSecondPerDownload: maxBytes,
}
}
func (bm *BandwidthManager) RegisterDownload(userID uuid.UUID) (*RateLimiter, error) {
bm.mu.Lock()
defer bm.mu.Unlock()
if _, exists := bm.activeDownloads[userID]; exists {
return nil, fmt.Errorf("download already registered for user %s", userID)
}
rateLimiter := NewRateLimiter(bm.bytesPerSecondPerDownload)
bm.activeDownloads[userID] = &activeDownload{
userID: userID,
rateLimiter: rateLimiter,
}
bm.recalculateRates()
return rateLimiter, nil
}
func (bm *BandwidthManager) UnregisterDownload(userID uuid.UUID) {
bm.mu.Lock()
defer bm.mu.Unlock()
delete(bm.activeDownloads, userID)
bm.recalculateRates()
}
func (bm *BandwidthManager) GetActiveDownloadCount() int {
bm.mu.RLock()
defer bm.mu.RUnlock()
return len(bm.activeDownloads)
}
func (bm *BandwidthManager) recalculateRates() {
activeCount := len(bm.activeDownloads)
if activeCount == 0 {
bm.bytesPerSecondPerDownload = bm.maxTotalBytesPerSecond
return
}
newRate := bm.maxTotalBytesPerSecond / int64(activeCount)
bm.bytesPerSecondPerDownload = newRate
for _, download := range bm.activeDownloads {
download.rateLimiter.UpdateRate(newRate)
}
}

View File

@@ -0,0 +1,150 @@
package backups_download
import (
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_BandwidthManager_RegisterSingleDownload(t *testing.T) {
throughputMBs := 100
manager := NewBandwidthManager(throughputMBs)
expectedBytesPerSec := int64(100 * 1024 * 1024 * 75 / 100)
assert.Equal(t, expectedBytesPerSec, manager.maxTotalBytesPerSecond)
assert.Equal(t, expectedBytesPerSec, manager.bytesPerSecondPerDownload)
userID := uuid.New()
rateLimiter, err := manager.RegisterDownload(userID)
assert.NoError(t, err)
assert.NotNil(t, rateLimiter)
assert.Equal(t, 1, manager.GetActiveDownloadCount())
assert.Equal(t, expectedBytesPerSec, manager.bytesPerSecondPerDownload)
assert.Equal(t, expectedBytesPerSec, rateLimiter.bytesPerSecond)
}
func Test_BandwidthManager_RegisterMultipleDownloads_BandwidthShared(t *testing.T) {
throughputMBs := 100
manager := NewBandwidthManager(throughputMBs)
maxBytes := int64(100 * 1024 * 1024 * 75 / 100)
user1 := uuid.New()
rateLimiter1, err := manager.RegisterDownload(user1)
assert.NoError(t, err)
assert.Equal(t, maxBytes, rateLimiter1.bytesPerSecond)
user2 := uuid.New()
rateLimiter2, err := manager.RegisterDownload(user2)
assert.NoError(t, err)
expectedPerDownload := maxBytes / 2
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
user3 := uuid.New()
rateLimiter3, err := manager.RegisterDownload(user3)
assert.NoError(t, err)
expectedPerDownload = maxBytes / 3
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
assert.Equal(t, expectedPerDownload, rateLimiter3.bytesPerSecond)
assert.Equal(t, 3, manager.GetActiveDownloadCount())
}
func Test_BandwidthManager_UnregisterDownload_BandwidthRebalanced(t *testing.T) {
throughputMBs := 100
manager := NewBandwidthManager(throughputMBs)
maxBytes := int64(100 * 1024 * 1024 * 75 / 100)
user1 := uuid.New()
rateLimiter1, _ := manager.RegisterDownload(user1)
user2 := uuid.New()
_, _ = manager.RegisterDownload(user2)
user3 := uuid.New()
rateLimiter3, _ := manager.RegisterDownload(user3)
assert.Equal(t, 3, manager.GetActiveDownloadCount())
expectedPerDownload := maxBytes / 3
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
manager.UnregisterDownload(user2)
assert.Equal(t, 2, manager.GetActiveDownloadCount())
expectedPerDownload = maxBytes / 2
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
assert.Equal(t, expectedPerDownload, rateLimiter3.bytesPerSecond)
manager.UnregisterDownload(user1)
assert.Equal(t, 1, manager.GetActiveDownloadCount())
assert.Equal(t, maxBytes, manager.bytesPerSecondPerDownload)
assert.Equal(t, maxBytes, rateLimiter3.bytesPerSecond)
manager.UnregisterDownload(user3)
assert.Equal(t, 0, manager.GetActiveDownloadCount())
assert.Equal(t, maxBytes, manager.bytesPerSecondPerDownload)
}
func Test_BandwidthManager_RegisterDuplicateUser_ReturnsError(t *testing.T) {
manager := NewBandwidthManager(100)
userID := uuid.New()
_, err := manager.RegisterDownload(userID)
assert.NoError(t, err)
_, err = manager.RegisterDownload(userID)
assert.Error(t, err)
assert.Contains(t, err.Error(), "download already registered")
}
func Test_RateLimiter_TokenBucketBasic(t *testing.T) {
bytesPerSec := int64(1024 * 1024)
limiter := NewRateLimiter(bytesPerSec)
assert.Equal(t, bytesPerSec, limiter.bytesPerSecond)
assert.Equal(t, bytesPerSec*2, limiter.bucketSize)
start := time.Now()
limiter.Wait(512 * 1024)
elapsed := time.Since(start)
assert.Less(t, elapsed, 100*time.Millisecond)
}
func Test_RateLimiter_UpdateRate(t *testing.T) {
limiter := NewRateLimiter(1024 * 1024)
assert.Equal(t, int64(1024*1024), limiter.bytesPerSecond)
newRate := int64(2 * 1024 * 1024)
limiter.UpdateRate(newRate)
assert.Equal(t, newRate, limiter.bytesPerSecond)
assert.Equal(t, newRate*2, limiter.bucketSize)
}
func Test_RateLimiter_ThrottlesCorrectly(t *testing.T) {
bytesPerSec := int64(1024 * 1024)
limiter := NewRateLimiter(bytesPerSec)
limiter.availableTokens = 0
start := time.Now()
limiter.Wait(bytesPerSec / 2)
elapsed := time.Since(start)
assert.GreaterOrEqual(t, elapsed, 400*time.Millisecond)
assert.LessOrEqual(t, elapsed, 700*time.Millisecond)
}

View File

@@ -0,0 +1,48 @@
package backups_download
import (
"databasus-backend/internal/config"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
)
var downloadTokenRepository = &DownloadTokenRepository{}
var downloadTracker = NewDownloadTracker(cache_utils.GetValkeyClient())
var bandwidthManager *BandwidthManager
var downloadTokenService *DownloadTokenService
var downloadTokenBackgroundService *DownloadTokenBackgroundService
func init() {
env := config.GetEnv()
throughputMBs := env.NodeNetworkThroughputMBs
if throughputMBs == 0 {
throughputMBs = 125
}
bandwidthManager = NewBandwidthManager(throughputMBs)
downloadTokenService = &DownloadTokenService{
downloadTokenRepository,
logger.GetLogger(),
downloadTracker,
bandwidthManager,
}
downloadTokenBackgroundService = &DownloadTokenBackgroundService{
downloadTokenService,
logger.GetLogger(),
}
}
func GetDownloadTokenService() *DownloadTokenService {
return downloadTokenService
}
func GetDownloadTokenBackgroundService() *DownloadTokenBackgroundService {
return downloadTokenBackgroundService
}
func GetBandwidthManager() *BandwidthManager {
return bandwidthManager
}

View File

@@ -0,0 +1,9 @@
package backups_download
import "github.com/google/uuid"
type GenerateDownloadTokenResponse struct {
Token string `json:"token"`
Filename string `json:"filename"`
BackupID uuid.UUID `json:"backupId"`
}

View File

@@ -1,4 +1,4 @@
package download_token
package backups_download
import (
"time"

View File

@@ -0,0 +1,101 @@
package backups_download
import (
"io"
"sync"
"time"
)
type RateLimiter struct {
mu sync.Mutex
bytesPerSecond int64
bucketSize int64
availableTokens float64
lastRefill time.Time
}
func NewRateLimiter(bytesPerSecond int64) *RateLimiter {
if bytesPerSecond <= 0 {
bytesPerSecond = 1024 * 1024 * 100
}
return &RateLimiter{
bytesPerSecond: bytesPerSecond,
bucketSize: bytesPerSecond * 2,
availableTokens: float64(bytesPerSecond * 2),
lastRefill: time.Now().UTC(),
}
}
func (rl *RateLimiter) UpdateRate(bytesPerSecond int64) {
rl.mu.Lock()
defer rl.mu.Unlock()
if bytesPerSecond <= 0 {
bytesPerSecond = 1024 * 1024 * 100
}
rl.bytesPerSecond = bytesPerSecond
rl.bucketSize = bytesPerSecond * 2
if rl.availableTokens > float64(rl.bucketSize) {
rl.availableTokens = float64(rl.bucketSize)
}
}
func (rl *RateLimiter) Wait(bytes int64) {
rl.mu.Lock()
defer rl.mu.Unlock()
for {
now := time.Now().UTC()
elapsed := now.Sub(rl.lastRefill).Seconds()
tokensToAdd := elapsed * float64(rl.bytesPerSecond)
rl.availableTokens += tokensToAdd
if rl.availableTokens > float64(rl.bucketSize) {
rl.availableTokens = float64(rl.bucketSize)
}
rl.lastRefill = now
if rl.availableTokens >= float64(bytes) {
rl.availableTokens -= float64(bytes)
return
}
tokensNeeded := float64(bytes) - rl.availableTokens
waitTime := time.Duration(tokensNeeded/float64(rl.bytesPerSecond)*1000) * time.Millisecond
if waitTime < time.Millisecond {
waitTime = time.Millisecond
}
rl.mu.Unlock()
time.Sleep(waitTime)
rl.mu.Lock()
}
}
type RateLimitedReader struct {
reader io.ReadCloser
rateLimiter *RateLimiter
}
func NewRateLimitedReader(reader io.ReadCloser, limiter *RateLimiter) *RateLimitedReader {
return &RateLimitedReader{
reader: reader,
rateLimiter: limiter,
}
}
func (r *RateLimitedReader) Read(p []byte) (n int, err error) {
n, err = r.reader.Read(p)
if n > 0 {
r.rateLimiter.Wait(int64(n))
}
return n, err
}
func (r *RateLimitedReader) Close() error {
return r.reader.Close()
}

View File

@@ -1,4 +1,4 @@
package download_token
package backups_download
import (
"crypto/rand"

View File

@@ -0,0 +1,105 @@
package backups_download
import (
"errors"
"log/slog"
"time"
"github.com/google/uuid"
)
type DownloadTokenService struct {
repository *DownloadTokenRepository
logger *slog.Logger
downloadTracker *DownloadTracker
bandwidthManager *BandwidthManager
}
func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, error) {
if s.downloadTracker.IsDownloadInProgress(userID) {
return "", ErrDownloadAlreadyInProgress
}
token := GenerateSecureToken()
downloadToken := &DownloadToken{
Token: token,
BackupID: backupID,
UserID: userID,
ExpiresAt: time.Now().UTC().Add(5 * time.Minute),
Used: false,
}
if err := s.repository.Create(downloadToken); err != nil {
return "", err
}
s.logger.Info("Generated download token", "backupId", backupID, "userId", userID)
return token, nil
}
func (s *DownloadTokenService) ValidateAndConsume(
token string,
) (*DownloadToken, *RateLimiter, error) {
dt, err := s.repository.FindByToken(token)
if err != nil {
return nil, nil, err
}
if dt == nil {
return nil, nil, errors.New("invalid token")
}
if dt.Used {
return nil, nil, errors.New("token already used")
}
if time.Now().UTC().After(dt.ExpiresAt) {
return nil, nil, errors.New("token expired")
}
if err := s.downloadTracker.AcquireDownloadLock(dt.UserID); err != nil {
return nil, nil, err
}
rateLimiter, err := s.bandwidthManager.RegisterDownload(dt.UserID)
if err != nil {
s.downloadTracker.ReleaseDownloadLock(dt.UserID)
return nil, nil, err
}
dt.Used = true
if err := s.repository.Update(dt); err != nil {
s.logger.Error("Failed to mark token as used", "error", err)
}
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID, "userId", dt.UserID)
return dt, rateLimiter, nil
}
func (s *DownloadTokenService) RefreshDownloadLock(userID uuid.UUID) {
s.downloadTracker.RefreshDownloadLock(userID)
}
func (s *DownloadTokenService) ReleaseDownloadLock(userID uuid.UUID) {
s.downloadTracker.ReleaseDownloadLock(userID)
s.logger.Info("Released download lock", "userId", userID)
}
func (s *DownloadTokenService) IsDownloadInProgress(userID uuid.UUID) bool {
return s.downloadTracker.IsDownloadInProgress(userID)
}
func (s *DownloadTokenService) UnregisterDownload(userID uuid.UUID) {
s.bandwidthManager.UnregisterDownload(userID)
s.logger.Info("Unregistered from bandwidth manager", "userId", userID)
}
func (s *DownloadTokenService) CleanExpiredTokens() error {
now := time.Now().UTC()
if err := s.repository.DeleteExpired(now); err != nil {
return err
}
s.logger.Debug("Cleaned expired download tokens")
return nil
}

View File

@@ -0,0 +1,66 @@
package backups_download
import (
cache_utils "databasus-backend/internal/util/cache"
"errors"
"time"
"github.com/google/uuid"
"github.com/valkey-io/valkey-go"
)
const (
downloadLockPrefix = "backup_download_lock:"
downloadLockTTL = 5 * time.Second
downloadLockValue = "1"
downloadHeartbeatDelay = 3 * time.Second
)
var (
ErrDownloadAlreadyInProgress = errors.New("download already in progress for this user")
)
type DownloadTracker struct {
cache *cache_utils.CacheUtil[string]
}
func NewDownloadTracker(client valkey.Client) *DownloadTracker {
return &DownloadTracker{
cache: cache_utils.NewCacheUtil[string](client, downloadLockPrefix),
}
}
func (t *DownloadTracker) AcquireDownloadLock(userID uuid.UUID) error {
key := userID.String()
existingLock := t.cache.Get(key)
if existingLock != nil {
return ErrDownloadAlreadyInProgress
}
value := downloadLockValue
t.cache.Set(key, &value)
return nil
}
func (t *DownloadTracker) RefreshDownloadLock(userID uuid.UUID) {
key := userID.String()
value := downloadLockValue
t.cache.Set(key, &value)
}
func (t *DownloadTracker) ReleaseDownloadLock(userID uuid.UUID) {
key := userID.String()
t.cache.Invalidate(key)
}
func (t *DownloadTracker) IsDownloadInProgress(userID uuid.UUID) bool {
key := userID.String()
existingLock := t.cache.Get(key)
return existingLock != nil
}
func GetDownloadHeartbeatInterval() time.Duration {
return downloadHeartbeatDelay
}

View File

@@ -1,32 +0,0 @@
package download_token
import (
"databasus-backend/internal/config"
"log/slog"
"time"
)
type DownloadTokenBackgroundService struct {
downloadTokenService *DownloadTokenService
logger *slog.Logger
}
func (s *DownloadTokenBackgroundService) Run() {
s.logger.Info("Starting download token cleanup background service")
if config.IsShouldShutdown() {
return
}
for {
if config.IsShouldShutdown() {
return
}
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
s.logger.Error("Failed to clean expired download tokens", "error", err)
}
time.Sleep(1 * time.Minute)
}
}

View File

@@ -1,25 +0,0 @@
package download_token
import (
"databasus-backend/internal/util/logger"
)
var downloadTokenRepository = &DownloadTokenRepository{}
var downloadTokenService = &DownloadTokenService{
downloadTokenRepository,
logger.GetLogger(),
}
var downloadTokenBackgroundService = &DownloadTokenBackgroundService{
downloadTokenService,
logger.GetLogger(),
}
func GetDownloadTokenService() *DownloadTokenService {
return downloadTokenService
}
func GetDownloadTokenBackgroundService() *DownloadTokenBackgroundService {
return downloadTokenBackgroundService
}

View File

@@ -1,69 +0,0 @@
package download_token
import (
"errors"
"log/slog"
"time"
"github.com/google/uuid"
)
type DownloadTokenService struct {
repository *DownloadTokenRepository
logger *slog.Logger
}
func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, error) {
token := GenerateSecureToken()
downloadToken := &DownloadToken{
Token: token,
BackupID: backupID,
UserID: userID,
ExpiresAt: time.Now().UTC().Add(5 * time.Minute),
Used: false,
}
if err := s.repository.Create(downloadToken); err != nil {
return "", err
}
s.logger.Info("Generated download token", "backupId", backupID, "userId", userID)
return token, nil
}
func (s *DownloadTokenService) ValidateAndConsume(token string) (*DownloadToken, error) {
dt, err := s.repository.FindByToken(token)
if err != nil {
return nil, err
}
if dt == nil {
return nil, errors.New("invalid token")
}
if dt.Used {
return nil, errors.New("token already used")
}
if time.Now().UTC().After(dt.ExpiresAt) {
return nil, errors.New("token expired")
}
dt.Used = true
if err := s.repository.Update(dt); err != nil {
s.logger.Error("Failed to mark token as used", "error", err)
}
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID)
return dt, nil
}
func (s *DownloadTokenService) CleanExpiredTokens() error {
now := time.Now().UTC()
if err := s.repository.DeleteExpired(now); err != nil {
return err
}
s.logger.Debug("Cleaned expired download tokens")
return nil
}

View File

@@ -1,10 +1,9 @@
package backups
import (
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
"io"
"github.com/google/uuid"
)
type GetBackupsRequest struct {
@@ -14,23 +13,17 @@ type GetBackupsRequest struct {
}
type GetBackupsResponse struct {
Backups []*Backup `json:"backups"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
Backups []*backups_core.Backup `json:"backups"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
type GenerateDownloadTokenResponse struct {
Token string `json:"token"`
Filename string `json:"filename"`
BackupID uuid.UUID `json:"backupId"`
}
type decryptionReaderCloser struct {
type DecryptionReaderCloser struct {
*encryption.DecryptionReader
baseReader io.ReadCloser
BaseReader io.ReadCloser
}
func (r *decryptionReaderCloser) Close() error {
return r.baseReader.Close()
func (r *DecryptionReaderCloser) Close() error {
return r.BaseReader.Close()
}

View File

@@ -1,24 +1,23 @@
package backups
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"log/slog"
"slices"
"strings"
"time"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/download_token"
"databasus-backend/internal/features/backups/backups/backuping"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
users_models "databasus-backend/internal/features/users/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
@@ -29,26 +28,27 @@ import (
type BackupService struct {
databaseService *databases.DatabaseService
storageService *storages.StorageService
backupRepository *BackupRepository
backupRepository *backups_core.BackupRepository
notifierService *notifiers.NotifierService
notificationSender NotificationSender
notificationSender backups_core.NotificationSender
backupConfigService *backups_config.BackupConfigService
secretKeyService *encryption_secrets.SecretKeyService
fieldEncryptor util_encryption.FieldEncryptor
createBackupUseCase CreateBackupUsecase
createBackupUseCase backups_core.CreateBackupUsecase
logger *slog.Logger
backupRemoveListeners []BackupRemoveListener
backupRemoveListeners []backups_core.BackupRemoveListener
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
backupContextManager *BackupContextManager
downloadTokenService *download_token.DownloadTokenService
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
taskCancelManager *task_cancellation.TaskCancelManager
downloadTokenService *backups_download.DownloadTokenService
backupSchedulerService *backuping.BackupsScheduler
}
func (s *BackupService) AddBackupRemoveListener(listener BackupRemoveListener) {
func (s *BackupService) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
s.backupRemoveListeners = append(s.backupRemoveListeners, listener)
}
@@ -91,7 +91,7 @@ func (s *BackupService) MakeBackupWithAuth(
return errors.New("insufficient permissions to create backup for this database")
}
go s.MakeBackup(databaseID, true)
s.backupSchedulerService.StartBackup(databaseID, true)
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Backup manually initiated for database: %s", database.Name),
@@ -175,7 +175,7 @@ func (s *BackupService) DeleteBackup(
return errors.New("insufficient permissions to delete backup for this database")
}
if backup.Status == BackupStatusInProgress {
if backup.Status == backups_core.BackupStatusInProgress {
return errors.New("backup is in progress")
}
@@ -192,260 +192,7 @@ func (s *BackupService) DeleteBackup(
return s.deleteBackup(backup)
}
func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
database, err := s.databaseService.GetDatabaseByID(databaseID)
if err != nil {
s.logger.Error("Failed to get database by ID", "error", err)
return
}
lastBackup, err := s.backupRepository.FindLastByDatabaseID(databaseID)
if err != nil {
s.logger.Error("Failed to find last backup by database ID", "error", err)
return
}
if lastBackup != nil && lastBackup.Status == BackupStatusInProgress {
s.logger.Error("Backup is in progress")
return
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
return
}
if backupConfig.StorageID == nil {
s.logger.Error("Backup config storage ID is not defined")
return
}
storage, err := s.storageService.GetStorageByID(*backupConfig.StorageID)
if err != nil {
s.logger.Error("Failed to get storage by ID", "error", err)
return
}
backup := &Backup{
DatabaseID: databaseID,
StorageID: storage.ID,
Status: BackupStatusInProgress,
BackupSizeMb: 0,
CreatedAt: time.Now().UTC(),
}
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save backup", "error", err)
return
}
start := time.Now().UTC()
backupProgressListener := func(
completedMBs float64,
) {
backup.BackupSizeMb = completedMBs
backup.BackupDurationMs = time.Since(start).Milliseconds()
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to update backup progress", "error", err)
}
}
ctx, cancel := context.WithCancel(context.Background())
s.backupContextManager.RegisterBackup(backup.ID, cancel)
defer s.backupContextManager.UnregisterBackup(backup.ID)
backupMetadata, err := s.createBackupUseCase.Execute(
ctx,
backup.ID,
backupConfig,
database,
storage,
backupProgressListener,
)
if err != nil {
errMsg := err.Error()
// Check if backup was cancelled (not due to shutdown)
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
strings.Contains(errMsg, "context canceled") ||
errors.Is(err, context.Canceled)
isShutdown := strings.Contains(errMsg, "shutdown")
if isCancelled && !isShutdown {
backup.Status = BackupStatusCanceled
backup.BackupDurationMs = time.Since(start).Milliseconds()
backup.BackupSizeMb = 0
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save cancelled backup", "error", err)
}
// Delete partial backup from storage
storage, storageErr := s.storageService.GetStorageByID(backup.StorageID)
if storageErr == nil {
if deleteErr := storage.DeleteFile(s.fieldEncryptor, backup.ID); deleteErr != nil {
s.logger.Error(
"Failed to delete partial backup file",
"backupId",
backup.ID,
"error",
deleteErr,
)
}
}
return
}
backup.FailMessage = &errMsg
backup.Status = BackupStatusFailed
backup.BackupDurationMs = time.Since(start).Milliseconds()
backup.BackupSizeMb = 0
if updateErr := s.databaseService.SetBackupError(databaseID, errMsg); updateErr != nil {
s.logger.Error(
"Failed to update database last backup time",
"databaseId",
databaseID,
"error",
updateErr,
)
}
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save backup", "error", err)
}
s.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupFailed,
&errMsg,
)
return
}
backup.Status = BackupStatusCompleted
backup.BackupDurationMs = time.Since(start).Milliseconds()
// Update backup with encryption metadata if provided
if backupMetadata != nil {
backup.EncryptionSalt = backupMetadata.EncryptionSalt
backup.EncryptionIV = backupMetadata.EncryptionIV
backup.Encryption = backupMetadata.Encryption
}
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save backup", "error", err)
return
}
// Update database last backup time
now := time.Now().UTC()
if updateErr := s.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
s.logger.Error(
"Failed to update database last backup time",
"databaseId",
databaseID,
"error",
updateErr,
)
}
if backup.Status != BackupStatusCompleted && !isLastTry {
return
}
s.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupSuccess,
nil,
)
}
func (s *BackupService) SendBackupNotification(
backupConfig *backups_config.BackupConfig,
backup *Backup,
notificationType backups_config.BackupNotificationType,
errorMessage *string,
) {
database, err := s.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
if err != nil {
return
}
workspace, err := s.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
if err != nil {
return
}
for _, notifier := range database.Notifiers {
if !slices.Contains(
backupConfig.SendNotificationsOn,
notificationType,
) {
continue
}
title := ""
switch notificationType {
case backups_config.NotificationBackupFailed:
title = fmt.Sprintf(
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
database.Name,
workspace.Name,
)
case backups_config.NotificationBackupSuccess:
title = fmt.Sprintf(
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
database.Name,
workspace.Name,
)
}
message := ""
if errorMessage != nil {
message = *errorMessage
} else {
// Format size conditionally
var sizeStr string
if backup.BackupSizeMb < 1024 {
sizeStr = fmt.Sprintf("%.2f MB", backup.BackupSizeMb)
} else {
sizeGB := backup.BackupSizeMb / 1024
sizeStr = fmt.Sprintf("%.2f GB", sizeGB)
}
// Format duration as "0m 0s 0ms"
totalMs := backup.BackupDurationMs
minutes := totalMs / (1000 * 60)
seconds := (totalMs % (1000 * 60)) / 1000
durationStr := fmt.Sprintf("%dm %ds", minutes, seconds)
message = fmt.Sprintf(
"Backup completed successfully in %s.\nCompressed backup size: %s",
durationStr,
sizeStr,
)
}
s.notificationSender.SendNotification(
&notifier,
title,
message,
)
}
}
func (s *BackupService) GetBackup(backupID uuid.UUID) (*Backup, error) {
func (s *BackupService) GetBackup(backupID uuid.UUID) (*backups_core.Backup, error) {
return s.backupRepository.FindByID(backupID)
}
@@ -475,11 +222,11 @@ func (s *BackupService) CancelBackup(
return errors.New("insufficient permissions to cancel backup for this database")
}
if backup.Status != BackupStatusInProgress {
if backup.Status != backups_core.BackupStatusInProgress {
return errors.New("backup is not in progress")
}
if err := s.backupContextManager.CancelBackup(backupID); err != nil {
if err := s.taskCancelManager.CancelTask(backupID); err != nil {
return err
}
@@ -499,7 +246,7 @@ func (s *BackupService) CancelBackup(
func (s *BackupService) GetBackupFile(
user *users_models.User,
backupID uuid.UUID,
) (io.ReadCloser, *Backup, *databases.Database, error) {
) (io.ReadCloser, *backups_core.Backup, *databases.Database, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, nil, nil, err
@@ -545,7 +292,7 @@ func (s *BackupService) GetBackupFile(
return reader, backup, database, nil
}
func (s *BackupService) deleteBackup(backup *Backup) error {
func (s *BackupService) deleteBackup(backup *backups_core.Backup) error {
for _, listener := range s.backupRemoveListeners {
if err := listener.OnBeforeBackupRemove(backup); err != nil {
return err
@@ -571,7 +318,7 @@ func (s *BackupService) deleteBackup(backup *Backup) error {
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
databaseID,
BackupStatusInProgress,
backups_core.BackupStatusInProgress,
)
if err != nil {
return err
@@ -680,16 +427,16 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
s.logger.Info("Returning encrypted backup with decryption", "backupId", backupID)
return &decryptionReaderCloser{
decryptionReader,
fileReader,
return &DecryptionReaderCloser{
DecryptionReader: decryptionReader,
BaseReader: fileReader,
}, nil
}
func (s *BackupService) GenerateDownloadToken(
user *users_models.User,
backupID uuid.UUID,
) (*GenerateDownloadTokenResponse, error) {
) (*backups_download.GenerateDownloadTokenResponse, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, err
@@ -725,20 +472,22 @@ func (s *BackupService) GenerateDownloadToken(
database.WorkspaceID,
)
return &GenerateDownloadTokenResponse{
return &backups_download.GenerateDownloadTokenResponse{
Token: token,
Filename: filename,
BackupID: backupID,
}, nil
}
func (s *BackupService) ValidateDownloadToken(token string) (*download_token.DownloadToken, error) {
func (s *BackupService) ValidateDownloadToken(
token string,
) (*backups_download.DownloadToken, *backups_download.RateLimiter, error) {
return s.downloadTokenService.ValidateAndConsume(token)
}
func (s *BackupService) GetBackupFileWithoutAuth(
backupID uuid.UUID,
) (io.ReadCloser, *Backup, *databases.Database, error) {
) (io.ReadCloser, *backups_core.Backup, *databases.Database, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, nil, nil, err
@@ -759,7 +508,7 @@ func (s *BackupService) GetBackupFileWithoutAuth(
func (s *BackupService) WriteAuditLogForDownload(
userID uuid.UUID,
backup *Backup,
backup *backups_core.Backup,
database *databases.Database,
) {
s.auditLogService.WriteAuditLog(
@@ -773,8 +522,24 @@ func (s *BackupService) WriteAuditLogForDownload(
)
}
func (s *BackupService) RefreshDownloadLock(userID uuid.UUID) {
s.downloadTokenService.RefreshDownloadLock(userID)
}
func (s *BackupService) ReleaseDownloadLock(userID uuid.UUID) {
s.downloadTokenService.ReleaseDownloadLock(userID)
}
func (s *BackupService) IsDownloadInProgress(userID uuid.UUID) bool {
return s.downloadTokenService.IsDownloadInProgress(userID)
}
func (s *BackupService) UnregisterDownload(userID uuid.UUID) {
s.downloadTokenService.UnregisterDownload(userID)
}
func (s *BackupService) generateBackupFilename(
backup *Backup,
backup *backups_core.Backup,
database *databases.Database,
) string {
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")

View File

@@ -4,6 +4,7 @@ import (
"testing"
"time"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
@@ -58,9 +59,9 @@ func WaitForBackupCompletion(
newestBackup := backups[0]
t.Logf("WaitForBackupCompletion: newest backup status: %s", newestBackup.Status)
if newestBackup.Status == BackupStatusCompleted ||
newestBackup.Status == BackupStatusFailed ||
newestBackup.Status == BackupStatusCanceled {
if newestBackup.Status == backups_core.BackupStatusCompleted ||
newestBackup.Status == backups_core.BackupStatusFailed ||
newestBackup.Status == backups_core.BackupStatusCanceled {
t.Logf(
"WaitForBackupCompletion: backup finished with status %s",
newestBackup.Status,

View File

@@ -122,6 +122,7 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
if mdb.IsHttps {
args = append(args, "--ssl")
args = append(args, "--skip-ssl-verify-server-cert")
}
if mdb.Database != nil && *mdb.Database != "" {

View File

@@ -515,11 +515,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
hasProcess := false
hasAllPrivileges := false
escapedDB := strings.ReplaceAll(database, "_", "\\_")
dbPattern := regexp.MustCompile(
fmt.Sprintf("(?i)ON\\s+[`'\"]?(%s|\\*)[`'\"]?\\.\\*", regexp.QuoteMeta(escapedDB)),
dbPatternStr := fmt.Sprintf(
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
regexp.QuoteMeta(database),
)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\.\*`)
dbPattern := regexp.MustCompile(dbPatternStr)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
allPrivilegesPattern := regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`)
for rows.Next() {
var grant string
@@ -527,23 +529,26 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
return "", fmt.Errorf("failed to scan grant: %w", err)
}
if regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`).MatchString(grant) {
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
hasAllPrivileges = true
}
isRelevantGrant := globalPattern.MatchString(grant) || dbPattern.MatchString(grant)
if allPrivilegesPattern.MatchString(grant) && isRelevantGrant {
hasAllPrivileges = true
}
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
if isRelevantGrant {
for _, priv := range backupPrivileges {
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
privPattern := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(priv) + `\b`)
if privPattern.MatchString(grant) {
detectedPrivileges[priv] = true
}
}
}
if globalPattern.MatchString(grant) &&
regexp.MustCompile(`(?i)\bPROCESS\b`).MatchString(grant) {
hasProcess = true
if globalPattern.MatchString(grant) {
processPattern := regexp.MustCompile(`(?i)\bPROCESS\b`)
if processPattern.MatchString(grant) {
hasProcess = true
}
}
}

View File

@@ -537,6 +537,163 @@ func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
dropUserSafe(container.DB, username)
}
func Test_TestConnection_DatabaseSpecificPrivilegesWithGlobalProcess_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MariadbVersion
port string
}{
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMariadbContainer(t, tc.port, tc.version)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS privilege_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`CREATE TABLE privilege_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO privilege_test (data) VALUES ('test1')`)
assert.NoError(t, err)
specificUsername := fmt.Sprintf("spec_%s", uuid.New().String()[:8])
specificPassword := "specificpass123"
_, err = container.DB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
specificUsername,
specificPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT SELECT, SHOW VIEW ON %s.* TO '%s'@'%%'",
container.Database,
specificUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT PROCESS ON *.* TO '%s'@'%%'",
specificUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer dropUserSafe(container.DB, specificUsername)
mariadbModel := &MariadbDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: specificUsername,
Password: specificPassword,
Database: &container.Database,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mariadbModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
})
}
}
func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
env := config.GetEnv()
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
defer container.DB.Close()
underscoreDbName := "test_db_name"
_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
}()
underscoreDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
container.Username, container.Password, container.Host, container.Port, underscoreDbName)
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
assert.NoError(t, err)
defer underscoreDB.Close()
_, err = underscoreDB.Exec(`
CREATE TABLE underscore_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = underscoreDB.Exec(`INSERT INTO underscore_test (data) VALUES ('test1')`)
assert.NoError(t, err)
underscoreUsername := fmt.Sprintf("under%s", uuid.New().String()[:8])
underscorePassword := "underscorepass123"
_, err = underscoreDB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
underscoreUsername,
underscorePassword,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec(fmt.Sprintf(
"GRANT SELECT, SHOW VIEW ON `%s`.* TO '%s'@'%%'",
underscoreDbName,
underscoreUsername,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer dropUserSafe(underscoreDB, underscoreUsername)
mariadbModel := &MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: container.Host,
Port: container.Port,
Username: underscoreUsername,
Password: underscorePassword,
Database: &underscoreDbName,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mariadbModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
}
type MariadbContainer struct {
Host string
Port int

View File

@@ -400,6 +400,7 @@ func HasPrivilege(privileges, priv string) bool {
func (m *MysqlDatabase) buildDSN(password string, database string) string {
tlsConfig := "false"
allowCleartext := ""
if m.IsHttps {
err := mysql.RegisterTLSConfig("mysql-skip-verify", &tls.Config{
@@ -411,16 +412,18 @@ func (m *MysqlDatabase) buildDSN(password string, database string) string {
}
tlsConfig = "mysql-skip-verify"
allowCleartext = "&allowCleartextPasswords=1"
}
return fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4",
"%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4%s",
m.Username,
password,
m.Host,
m.Port,
database,
tlsConfig,
allowCleartext,
)
}
@@ -486,11 +489,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
hasProcess := false
hasAllPrivileges := false
escapedDB := strings.ReplaceAll(database, "_", "\\_")
dbPattern := regexp.MustCompile(
fmt.Sprintf("(?i)ON\\s+[`'\"]?(%s|\\*)[`'\"]?\\.\\*", regexp.QuoteMeta(escapedDB)),
dbPatternStr := fmt.Sprintf(
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
regexp.QuoteMeta(database),
)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\.\*`)
dbPattern := regexp.MustCompile(dbPatternStr)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
allPrivilegesPattern := regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`)
for rows.Next() {
var grant string
@@ -498,23 +503,26 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
return "", fmt.Errorf("failed to scan grant: %w", err)
}
if regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`).MatchString(grant) {
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
hasAllPrivileges = true
}
isRelevantGrant := globalPattern.MatchString(grant) || dbPattern.MatchString(grant)
if allPrivilegesPattern.MatchString(grant) && isRelevantGrant {
hasAllPrivileges = true
}
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
if isRelevantGrant {
for _, priv := range backupPrivileges {
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
privPattern := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(priv) + `\b`)
if privPattern.MatchString(grant) {
detectedPrivileges[priv] = true
}
}
}
if globalPattern.MatchString(grant) &&
regexp.MustCompile(`(?i)\bPROCESS\b`).MatchString(grant) {
hasProcess = true
if globalPattern.MatchString(grant) {
processPattern := regexp.MustCompile(`(?i)\bPROCESS\b`)
if processPattern.MatchString(grant) {
hasProcess = true
}
}
}

View File

@@ -518,6 +518,162 @@ func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
assert.NoError(t, err)
}
func Test_TestConnection_DatabaseSpecificPrivilegesWithGlobalProcess_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MysqlVersion
port string
}{
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMysqlContainer(t, tc.port, tc.version)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS privilege_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`CREATE TABLE privilege_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO privilege_test (data) VALUES ('test1')`)
assert.NoError(t, err)
specificUsername := fmt.Sprintf("specific_%s", uuid.New().String()[:8])
specificPassword := "specificpass123"
_, err = container.DB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
specificUsername,
specificPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT SELECT, SHOW VIEW ON %s.* TO '%s'@'%%'",
container.Database,
specificUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT PROCESS ON *.* TO '%s'@'%%'",
specificUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", specificUsername),
)
}()
mysqlModel := &MysqlDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: specificUsername,
Password: specificPassword,
Database: &container.Database,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mysqlModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
})
}
}
func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
env := config.GetEnv()
container := connectToMysqlContainer(t, env.TestMysql80Port, tools.MysqlVersion80)
defer container.DB.Close()
underscoreDbName := "test_db_name"
_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
}()
underscoreDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
container.Username, container.Password, container.Host, container.Port, underscoreDbName)
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
assert.NoError(t, err)
defer underscoreDB.Close()
_, err = underscoreDB.Exec(`
CREATE TABLE underscore_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = underscoreDB.Exec(`INSERT INTO underscore_test (data) VALUES ('test1')`)
assert.NoError(t, err)
underscoreUsername := fmt.Sprintf("under_%s", uuid.New().String()[:8])
underscorePassword := "underscorepass123"
_, err = underscoreDB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
underscoreUsername,
underscorePassword,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec(fmt.Sprintf(
"GRANT SELECT, SHOW VIEW ON `%s`.* TO '%s'@'%%'",
underscoreDbName,
underscoreUsername,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer func() {
_, _ = underscoreDB.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", underscoreUsername))
}()
mysqlModel := &MysqlDatabase{
Version: tools.MysqlVersion80,
Host: container.Host,
Port: container.Port,
Username: underscoreUsername,
Password: underscorePassword,
Database: &underscoreDbName,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mysqlModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
}
type MysqlContainer struct {
Host string
Port int

View File

@@ -85,6 +85,27 @@ func (p *PostgresqlDatabase) Validate() error {
return errors.New("cpu count must be greater than 0")
}
// Prevent Databasus from backing up itself
// Databasus runs an internal PostgreSQL instance that should not be backed up through the UI
// because it would expose internal metadata to non-system administrators.
// To properly backup Databasus, see: https://databasus.com/faq#backup-databasus
if p.Database != nil && *p.Database != "" {
localhostHosts := []string{"localhost", "127.0.0.1", "172.17.0.1", "host.docker.internal"}
isLocalhost := false
for _, host := range localhostHosts {
if strings.EqualFold(p.Host, host) {
isLocalhost = true
break
}
}
if isLocalhost && strings.EqualFold(*p.Database, "databasus") {
return errors.New(
"backing up Databasus internal database is not allowed. To backup Databasus itself, see https://databasus.com/faq#backup-databasus",
)
}
}
return nil
}
@@ -671,7 +692,7 @@ func testSingleDatabaseConnection(
postgresDb.Version = detectedVersion
// Verify user has sufficient permissions for backup operations
if err := checkBackupPermissions(ctx, conn, *postgresDb.Database); err != nil {
if err := checkBackupPermissions(ctx, conn, postgresDb.IncludeSchemas); err != nil {
return err
}
@@ -707,7 +728,12 @@ func detectDatabaseVersion(ctx context.Context, conn *pgx.Conn) (tools.Postgresq
// checkBackupPermissions verifies the user has sufficient privileges for pg_dump backup.
// Required privileges: CONNECT on database, USAGE on schemas, SELECT on tables.
func checkBackupPermissions(ctx context.Context, conn *pgx.Conn, dbName string) error {
// If includeSchemas is specified, only checks permissions on those schemas.
func checkBackupPermissions(
ctx context.Context,
conn *pgx.Conn,
includeSchemas []string,
) error {
var missingPrivileges []string
// Check CONNECT privilege on database
@@ -723,14 +749,29 @@ func checkBackupPermissions(ctx context.Context, conn *pgx.Conn, dbName string)
// Check USAGE privilege on at least one non-system schema
var schemaCount int
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_namespace n
WHERE has_schema_privilege(current_user, n.nspname, 'USAGE')
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
AND n.nspname NOT LIKE 'pg_temp_%'
AND n.nspname NOT LIKE 'pg_toast_temp_%'
`).Scan(&schemaCount)
if len(includeSchemas) > 0 {
// Check only the specified schemas
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_namespace n
WHERE has_schema_privilege(current_user, n.nspname, 'USAGE')
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
AND n.nspname NOT LIKE 'pg_temp_%'
AND n.nspname NOT LIKE 'pg_toast_temp_%'
AND n.nspname = ANY($1::text[])
`, includeSchemas).Scan(&schemaCount)
} else {
// Check all non-system schemas
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_namespace n
WHERE has_schema_privilege(current_user, n.nspname, 'USAGE')
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
AND n.nspname NOT LIKE 'pg_temp_%'
AND n.nspname NOT LIKE 'pg_toast_temp_%'
`).Scan(&schemaCount)
}
if err != nil {
return fmt.Errorf("cannot check schema privileges: %w", err)
}
@@ -741,11 +782,28 @@ func checkBackupPermissions(ctx context.Context, conn *pgx.Conn, dbName string)
// Check SELECT privilege on at least one table (if tables exist)
// Use pg_tables from pg_catalog which shows all tables regardless of user privileges
var tableCount int
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_catalog.pg_tables t
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
`).Scan(&tableCount)
if len(includeSchemas) > 0 {
// Check only tables in the specified schemas
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_catalog.pg_tables t
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
AND t.schemaname NOT LIKE 'pg_temp_%'
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
AND t.schemaname = ANY($1::text[])
`, includeSchemas).Scan(&tableCount)
} else {
// Check all tables in non-system schemas
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_catalog.pg_tables t
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
AND t.schemaname NOT LIKE 'pg_temp_%'
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
`).Scan(&tableCount)
}
if err != nil {
return fmt.Errorf("cannot check table count: %w", err)
}
@@ -753,12 +811,30 @@ func checkBackupPermissions(ctx context.Context, conn *pgx.Conn, dbName string)
if tableCount > 0 {
// Check if user has SELECT on at least one of these tables
var selectableTableCount int
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_catalog.pg_tables t
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
`).Scan(&selectableTableCount)
if len(includeSchemas) > 0 {
// Check only tables in the specified schemas
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_catalog.pg_tables t
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
AND t.schemaname NOT LIKE 'pg_temp_%'
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
AND t.schemaname = ANY($1::text[])
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
`, includeSchemas).Scan(&selectableTableCount)
} else {
// Check all tables in non-system schemas
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_catalog.pg_tables t
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
AND t.schemaname NOT LIKE 'pg_temp_%'
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
`).Scan(&selectableTableCount)
}
if err != nil {
return fmt.Errorf("cannot check SELECT privileges: %w", err)
}

View File

@@ -705,6 +705,233 @@ func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
assert.Contains(t, err.Error(), "permission denied")
}
func Test_Validate_WhenLocalhostAndDatabasus_ReturnsError(t *testing.T) {
testCases := []struct {
name string
host string
username string
database string
}{
{
name: "localhost with databasus db",
host: "localhost",
username: "postgres",
database: "databasus",
},
{
name: "127.0.0.1 with databasus db",
host: "127.0.0.1",
username: "postgres",
database: "databasus",
},
{
name: "172.17.0.1 with databasus db",
host: "172.17.0.1",
username: "postgres",
database: "databasus",
},
{
name: "host.docker.internal with databasus db",
host: "host.docker.internal",
username: "postgres",
database: "databasus",
},
{
name: "LOCALHOST (uppercase) with DATABASUS db",
host: "LOCALHOST",
username: "POSTGRES",
database: "DATABASUS",
},
{
name: "LocalHost (mixed case) with DataBasus db",
host: "LocalHost",
username: "anyuser",
database: "DataBasus",
},
{
name: "localhost with databasus and any username",
host: "localhost",
username: "myuser",
database: "databasus",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
pgModel := &PostgresqlDatabase{
Host: tc.host,
Port: 5437,
Username: tc.username,
Password: "somepassword",
Database: &tc.database,
CpuCount: 1,
}
err := pgModel.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), "backing up Databasus internal database is not allowed")
assert.Contains(t, err.Error(), "https://databasus.com/faq#backup-databasus")
})
}
}
func Test_Validate_WhenNotLocalhostOrNotDatabasus_ValidatesSuccessfully(t *testing.T) {
testCases := []struct {
name string
host string
username string
database string
}{
{
name: "different host (remote server) with databasus db",
host: "192.168.1.100",
username: "postgres",
database: "databasus",
},
{
name: "different database name on localhost",
host: "localhost",
username: "postgres",
database: "myapp",
},
{
name: "all different",
host: "db.example.com",
username: "appuser",
database: "production",
},
{
name: "localhost with postgres database",
host: "localhost",
username: "postgres",
database: "postgres",
},
{
name: "remote host with databasus db name (allowed)",
host: "db.example.com",
username: "postgres",
database: "databasus",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
pgModel := &PostgresqlDatabase{
Host: tc.host,
Port: 5432,
Username: tc.username,
Password: "somepassword",
Database: &tc.database,
CpuCount: 1,
}
err := pgModel.Validate()
assert.NoError(t, err)
})
}
}
func Test_Validate_WhenDatabaseIsNil_ValidatesSuccessfully(t *testing.T) {
pgModel := &PostgresqlDatabase{
Host: "localhost",
Port: 5437,
Username: "postgres",
Password: "somepassword",
Database: nil,
CpuCount: 1,
}
err := pgModel.Validate()
assert.NoError(t, err)
}
func Test_Validate_WhenDatabaseIsEmpty_ValidatesSuccessfully(t *testing.T) {
emptyDb := ""
pgModel := &PostgresqlDatabase{
Host: "localhost",
Port: 5437,
Username: "postgres",
Password: "somepassword",
Database: &emptyDb,
CpuCount: 1,
}
err := pgModel.Validate()
assert.NoError(t, err)
}
func Test_Validate_WhenRequiredFieldsMissing_ReturnsError(t *testing.T) {
testCases := []struct {
name string
model *PostgresqlDatabase
expectedError string
}{
{
name: "missing host",
model: &PostgresqlDatabase{
Host: "",
Port: 5432,
Username: "user",
Password: "pass",
CpuCount: 1,
},
expectedError: "host is required",
},
{
name: "missing port",
model: &PostgresqlDatabase{
Host: "localhost",
Port: 0,
Username: "user",
Password: "pass",
CpuCount: 1,
},
expectedError: "port is required",
},
{
name: "missing username",
model: &PostgresqlDatabase{
Host: "localhost",
Port: 5432,
Username: "",
Password: "pass",
CpuCount: 1,
},
expectedError: "username is required",
},
{
name: "missing password",
model: &PostgresqlDatabase{
Host: "localhost",
Port: 5432,
Username: "user",
Password: "",
CpuCount: 1,
},
expectedError: "password is required",
},
{
name: "invalid cpu count",
model: &PostgresqlDatabase{
Host: "localhost",
Port: 5432,
Username: "user",
Password: "pass",
CpuCount: 0,
},
expectedError: "cpu count must be greater than 0",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.model.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), tc.expectedError)
})
}
}
type PostgresContainer struct {
Host string
Port int

View File

@@ -1,7 +1,7 @@
package healthcheck_attempt
import (
"databasus-backend/internal/config"
"context"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"log/slog"
"time"
@@ -13,18 +13,19 @@ type HealthcheckAttemptBackgroundService struct {
logger *slog.Logger
}
func (s *HealthcheckAttemptBackgroundService) Run() {
func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) {
// first healthcheck immediately
s.checkDatabases()
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for range ticker.C {
if config.IsShouldShutdown() {
break
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.checkDatabases()
}
s.checkDatabases()
}
}

View File

@@ -675,6 +675,10 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.example.com/test",
WebhookMethod: webhook_notifier.WebhookMethodPOST,
Headers: []webhook_notifier.WebhookHeader{
{Key: "Authorization", Value: "Bearer my-secret-token"},
{Key: "X-Custom-Header", Value: "custom-value"},
},
},
}
},
@@ -687,14 +691,40 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.example.com/updated",
WebhookMethod: webhook_notifier.WebhookMethodGET,
Headers: []webhook_notifier.WebhookHeader{
{Key: "Authorization", Value: "Bearer updated-token"},
},
},
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
// No sensitive data to verify for webhook
assert.NotEmpty(
t,
notifier.WebhookNotifier.WebhookURL,
"WebhookURL should be visible",
)
// Verify header values are encrypted in DB
assert.True(
t,
isEncrypted(notifier.WebhookNotifier.Headers[0].Value),
"Header value should be encrypted in DB",
)
decrypted := decryptField(
t,
notifier.ID,
notifier.WebhookNotifier.Headers[0].Value,
)
assert.Equal(t, "Bearer updated-token", decrypted)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
// No sensitive data to hide for webhook
assert.NotEmpty(
t,
notifier.WebhookNotifier.WebhookURL,
"WebhookURL should be visible",
)
for _, header := range notifier.WebhookNotifier.Headers {
assert.Empty(t, header.Value, "Header value should be hidden")
}
},
},
}
@@ -905,7 +935,7 @@ func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
},
},
{
name: "Webhook Notifier - WebhookURL encrypted",
name: "Webhook Notifier - Header values encrypted, URL not encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
@@ -914,17 +944,48 @@ func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.example.com/test456",
WebhookMethod: webhook_notifier.WebhookMethodPOST,
Headers: []webhook_notifier.WebhookHeader{
{Key: "Authorization", Value: "Bearer secret-token-12345"},
{Key: "X-API-Key", Value: "api-key-67890"},
},
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
assert.False(
t,
isEncrypted(notifier.WebhookNotifier.WebhookURL),
"WebhookURL should be encrypted",
"WebhookURL should NOT be encrypted",
)
decrypted := decryptField(t, notifier.ID, notifier.WebhookNotifier.WebhookURL)
assert.Equal(t, "https://webhook.example.com/test456", decrypted)
assert.Equal(
t,
"https://webhook.example.com/test456",
notifier.WebhookNotifier.WebhookURL,
)
assert.True(
t,
isEncrypted(notifier.WebhookNotifier.Headers[0].Value),
"Header value should be encrypted",
)
decrypted1 := decryptField(
t,
notifier.ID,
notifier.WebhookNotifier.Headers[0].Value,
)
assert.Equal(t, "Bearer secret-token-12345", decrypted1)
assert.True(
t,
isEncrypted(notifier.WebhookNotifier.Headers[1].Value),
"Header value should be encrypted",
)
decrypted2 := decryptField(
t,
notifier.ID,
notifier.WebhookNotifier.Headers[1].Value,
)
assert.Equal(t, "api-key-67890", decrypted2)
},
},
}

View File

@@ -21,6 +21,10 @@ type WebhookHeader struct {
Value string `json:"value"`
}
// Before both WebhookURL, BodyTemplate and HeadersJSON were considered
// as sensetive data and it was causing issues. Now only headers values
// considered as sensetive data, but we try to decrypt webhook URL and
// body template for backward combability
type WebhookNotifier struct {
NotifierID uuid.UUID `json:"notifierId" gorm:"primaryKey;column:notifier_id"`
WebhookURL string `json:"webhookUrl" gorm:"not null;column:webhook_url"`
@@ -58,6 +62,20 @@ func (t *WebhookNotifier) AfterFind(_ *gorm.DB) error {
}
}
encryptor := encryption.GetFieldEncryptor()
if t.WebhookURL != "" {
if decrypted, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL); err == nil {
t.WebhookURL = decrypted
}
}
if t.BodyTemplate != nil && *t.BodyTemplate != "" {
if decrypted, err := encryptor.Decrypt(t.NotifierID, *t.BodyTemplate); err == nil {
t.BodyTemplate = &decrypted
}
}
return nil
}
@@ -79,22 +97,24 @@ func (t *WebhookNotifier) Send(
heading string,
message string,
) error {
webhookURL, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
if err := t.decryptHeadersForSending(encryptor); err != nil {
return err
}
switch t.WebhookMethod {
case WebhookMethodGET:
return t.sendGET(webhookURL, heading, message, logger)
return t.sendGET(t.WebhookURL, heading, message, logger)
case WebhookMethodPOST:
return t.sendPOST(webhookURL, heading, message, logger)
return t.sendPOST(t.WebhookURL, heading, message, logger)
default:
return fmt.Errorf("unsupported webhook method: %s", t.WebhookMethod)
}
}
func (t *WebhookNotifier) HideSensitiveData() {
for i := range t.Headers {
t.Headers[i].Value = ""
}
}
func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
@@ -105,14 +125,15 @@ func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
}
func (t *WebhookNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if t.WebhookURL != "" {
encrypted, err := encryptor.Encrypt(t.NotifierID, t.WebhookURL)
for i := range t.Headers {
if t.Headers[i].Value != "" {
encrypted, err := encryptor.Encrypt(t.NotifierID, t.Headers[i].Value)
if err != nil {
return fmt.Errorf("failed to encrypt header value: %w", err)
}
if err != nil {
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
t.Headers[i].Value = encrypted
}
t.WebhookURL = encrypted
}
return nil
@@ -241,3 +262,15 @@ func escapeJSONString(s string) string {
return string(b[1 : len(b)-1])
}
func (t *WebhookNotifier) decryptHeadersForSending(encryptor encryption.FieldEncryptor) error {
for i := range t.Headers {
if t.Headers[i].Value != "" {
if decrypted, err := encryptor.Decrypt(t.NotifierID, t.Headers[i].Value); err == nil {
t.Headers[i].Value = decrypted
}
}
}
return nil
}

View File

@@ -1,6 +1,7 @@
package restores
import (
"context"
"databasus-backend/internal/features/restores/enums"
"log/slog"
)
@@ -10,7 +11,7 @@ type RestoreBackgroundService struct {
logger *slog.Logger
}
func (s *RestoreBackgroundService) Run() {
func (s *RestoreBackgroundService) Run(ctx context.Context) {
if err := s.failRestoresInProgress(); err != nil {
s.logger.Error("Failed to fail restores in progress", "error", err)
panic(err)

View File

@@ -19,6 +19,7 @@ import (
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/mysql"
@@ -274,7 +275,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
var backup *backups.Backup
var backup *backups_core.Backup
var request RestoreBackupRequest
if tc.dbType == databases.DatabaseTypePostgres {
@@ -321,7 +322,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
}
// Set huge backup size (10 TB) that would fail disk validation if checked
repo := &backups.BackupRepository{}
repo := &backups_core.BackupRepository{}
backup.BackupSizeMb = 10485760.0
err := repo.Save(backup)
assert.NoError(t, err)
@@ -368,7 +369,7 @@ func createTestDatabaseWithBackupForRestore(
workspace *workspaces_models.Workspace,
owner *users_dto.SignInResponseDTO,
router *gin.Engine,
) (*databases.Database, *backups.Backup) {
) (*databases.Database, *backups_core.Backup) {
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
@@ -504,7 +505,7 @@ func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
func createTestBackup(
database *databases.Database,
owner *users_dto.SignInResponseDTO,
) *backups.Backup {
) *backups_core.Backup {
fieldEncryptor := util_encryption.GetFieldEncryptor()
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
@@ -517,17 +518,17 @@ func createTestBackup(
panic("No storage found for workspace")
}
backup := &backups.Backup{
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storages[0].ID,
Status: backups.BackupStatusCompleted,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
CreatedAt: time.Now().UTC(),
}
repo := &backups.BackupRepository{}
repo := &backups_core.BackupRepository{}
if err := repo.Save(backup); err != nil {
panic(err)
}

View File

@@ -1,7 +1,7 @@
package models
import (
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/restores/enums"
"time"
@@ -13,7 +13,7 @@ type Restore struct {
Status enums.RestoreStatus `json:"status" gorm:"column:status;type:text;not null"`
BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"`
Backup *backups.Backup
Backup *backups_core.Backup
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`

View File

@@ -3,6 +3,7 @@ package restores
import (
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
@@ -36,7 +37,7 @@ type RestoreService struct {
diskService *disk.DiskService
}
func (s *RestoreService) OnBeforeBackupRemove(backup *backups.Backup) error {
func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error {
restores, err := s.restoreRepository.FindByBackupID(backup.ID)
if err != nil {
return err
@@ -153,10 +154,10 @@ func (s *RestoreService) RestoreBackupWithAuth(
}
func (s *RestoreService) RestoreBackup(
backup *backups.Backup,
backup *backups_core.Backup,
requestDTO RestoreBackupRequest,
) error {
if backup.Status != backups.BackupStatusCompleted {
if backup.Status != backups_core.BackupStatusCompleted {
return errors.New("backup is not completed")
}
@@ -370,7 +371,7 @@ func (s *RestoreService) validateVersionCompatibility(
}
func (s *RestoreService) validateDiskSpace(
backup *backups.Backup,
backup *backups_core.Backup,
requestDTO RestoreBackupRequest,
) error {
// Only validate disk space for PostgreSQL when file-based restore is needed:

View File

@@ -18,7 +18,7 @@ import (
"github.com/klauspost/compress/zstd"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -40,7 +40,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
if originalDB.Type != databases.DatabaseTypeMariadb {
@@ -71,6 +71,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
if mdb.IsHttps {
args = append(args, "--ssl")
args = append(args, "--skip-ssl-verify-server-cert")
}
if mdb.Database != nil && *mdb.Database != "" {
@@ -98,7 +99,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
mariadbBin string,
args []string,
password string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
mdbConfig *mariadbtypes.MariadbDatabase,
) error {
@@ -162,7 +163,7 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
args []string,
myCnfFile string,
backupReader io.ReadCloser,
backup *backups.Backup,
backup *backups_core.Backup,
) error {
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
@@ -225,7 +226,7 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
func (uc *RestoreMariadbBackupUsecase) setupDecryption(
reader io.Reader,
backup *backups.Backup,
backup *backups_core.Backup,
) (io.Reader, error) {
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
return nil, fmt.Errorf("backup is encrypted but missing encryption metadata")

View File

@@ -14,7 +14,7 @@ import (
"time"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -40,7 +40,7 @@ func (uc *RestoreMongodbBackupUsecase) Execute(
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
if originalDB.Type != databases.DatabaseTypeMongodb {
@@ -124,7 +124,7 @@ func (uc *RestoreMongodbBackupUsecase) buildMongorestoreArgs(
func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
mongorestoreBin string,
args []string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
ctx, cancel := context.WithTimeout(context.Background(), restoreTimeout)
@@ -166,7 +166,7 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
mongorestoreBin string,
args []string,
backupReader io.ReadCloser,
backup *backups.Backup,
backup *backups_core.Backup,
) error {
cmd := exec.CommandContext(ctx, mongorestoreBin, args...)
@@ -231,7 +231,7 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
func (uc *RestoreMongodbBackupUsecase) setupDecryption(
reader io.Reader,
backup *backups.Backup,
backup *backups_core.Backup,
) (io.Reader, error) {
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
return nil, errors.New("encrypted backup missing salt or IV")

View File

@@ -18,7 +18,7 @@ import (
"github.com/klauspost/compress/zstd"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -40,7 +40,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute(
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
if originalDB.Type != databases.DatabaseTypeMysql {
@@ -98,7 +98,7 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
mysqlBin string,
args []string,
password string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
myConfig *mysqltypes.MysqlDatabase,
) error {
@@ -154,7 +154,7 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
args []string,
myCnfFile string,
backupReader io.ReadCloser,
backup *backups.Backup,
backup *backups_core.Backup,
) error {
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
@@ -217,7 +217,7 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
func (uc *RestoreMysqlBackupUsecase) setupDecryption(
reader io.Reader,
backup *backups.Backup,
backup *backups_core.Backup,
) (io.Reader, error) {
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
return nil, fmt.Errorf("backup is encrypted but missing encryption metadata")

View File

@@ -15,7 +15,7 @@ import (
"time"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -39,7 +39,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
isExcludeExtensions bool,
) error {
@@ -86,7 +86,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
originalDB *databases.Database,
pgBin string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
pg *pgtypes.PostgresqlDatabase,
isExcludeExtensions bool,
@@ -113,7 +113,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
originalDB *databases.Database,
pgBin string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
pg *pgtypes.PostgresqlDatabase,
) error {
@@ -321,7 +321,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
originalDB *databases.Database,
pgBin string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
pg *pgtypes.PostgresqlDatabase,
isExcludeExtensions bool,
@@ -371,7 +371,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
pgBin string,
args []string,
password string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
pgConfig *pgtypes.PostgresqlDatabase,
isExcludeExtensions bool,
@@ -469,7 +469,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
// downloadBackupToTempFile downloads backup data from storage to a temporary file
func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
ctx context.Context,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) (string, func(), error) {
// Create temporary directory for backup data

View File

@@ -3,7 +3,7 @@ package usecases
import (
"errors"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/restores/models"
@@ -26,7 +26,7 @@ func (uc *RestoreBackupUsecase) Execute(
restore models.Restore,
originalDB *databases.Database,
restoringToDB *databases.Database,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
isExcludeExtensions bool,
) error {

View File

@@ -1,13 +1,14 @@
package system_healthcheck
import (
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
"databasus-backend/internal/features/disk"
)
var healthcheckService = &HealthcheckService{
disk.GetDiskService(),
backups.GetBackupBackgroundService(),
backuping.GetBackupsScheduler(),
backuping.GetBackuperNode(),
}
var healthcheckController = &HealthcheckController{
healthcheckService,

View File

@@ -1,7 +1,8 @@
package system_healthcheck
import (
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups/backuping"
"databasus-backend/internal/features/disk"
"databasus-backend/internal/storage"
"errors"
@@ -9,7 +10,8 @@ import (
type HealthcheckService struct {
diskService *disk.DiskService
backupBackgroundService *backups.BackupBackgroundService
backupBackgroundService *backuping.BackupsScheduler
backuperNode *backuping.BackuperNode
}
func (s *HealthcheckService) IsHealthy() error {
@@ -29,8 +31,16 @@ func (s *HealthcheckService) IsHealthy() error {
return errors.New("cannot connect to the database")
}
if !s.backupBackgroundService.IsBackupsWorkerRunning() {
return errors.New("backups are not running for more than 5 minutes")
if config.GetEnv().IsPrimaryNode {
if !s.backupBackgroundService.IsSchedulerRunning() {
return errors.New("backups are not running for more than 5 minutes")
}
}
if config.GetEnv().IsBackupNode {
if !s.backuperNode.IsBackuperRunning() {
return errors.New("backuper node is not running for more than 5 minutes")
}
}
return nil

View File

@@ -0,0 +1,75 @@
package task_cancellation
import (
"context"
cache_utils "databasus-backend/internal/util/cache"
"log/slog"
"sync"
"github.com/google/uuid"
)
const taskCancelChannel = "task:cancel"
type TaskCancelManager struct {
mu sync.RWMutex
cancelFuncs map[uuid.UUID]context.CancelFunc
pubsub *cache_utils.PubSubManager
logger *slog.Logger
}
func (m *TaskCancelManager) StartSubscription() {
ctx := context.Background()
handler := func(message string) {
taskID, err := uuid.Parse(message)
if err != nil {
m.logger.Error("Invalid task ID in cancel message", "message", message, "error", err)
return
}
m.mu.Lock()
defer m.mu.Unlock()
cancelFunc, exists := m.cancelFuncs[taskID]
if exists {
cancelFunc()
delete(m.cancelFuncs, taskID)
m.logger.Info("Cancelled task via Pub/Sub", "taskID", taskID)
}
}
err := m.pubsub.Subscribe(ctx, taskCancelChannel, handler)
if err != nil {
m.logger.Error("Failed to subscribe to task cancel channel", "error", err)
} else {
m.logger.Info("Successfully subscribed to task cancel channel")
}
}
func (m *TaskCancelManager) RegisterTask(task uuid.UUID, cancelFunc context.CancelFunc) {
m.mu.Lock()
defer m.mu.Unlock()
m.cancelFuncs[task] = cancelFunc
m.logger.Debug("Registered task", "taskID", task)
}
func (m *TaskCancelManager) CancelTask(taskID uuid.UUID) error {
ctx := context.Background()
err := m.pubsub.Publish(ctx, taskCancelChannel, taskID.String())
if err != nil {
m.logger.Error("Failed to publish cancel message", "taskID", taskID, "error", err)
return err
}
m.logger.Info("Published task cancel message", "taskID", taskID)
return nil
}
func (m *TaskCancelManager) UnregisterTask(taskID uuid.UUID) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.cancelFuncs, taskID)
m.logger.Debug("Unregistered task", "taskID", taskID)
}

View File

@@ -0,0 +1,200 @@
package task_cancellation
import (
"context"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_RegisterTask_TaskRegisteredSuccessfully(t *testing.T) {
manager := taskCancelManager
taskID := uuid.New()
_, cancel := context.WithCancel(context.Background())
defer cancel()
manager.RegisterTask(taskID, cancel)
manager.mu.RLock()
_, exists := manager.cancelFuncs[taskID]
manager.mu.RUnlock()
assert.True(t, exists, "Task should be registered")
}
func Test_UnregisterTask_TaskUnregisteredSuccessfully(t *testing.T) {
manager := taskCancelManager
taskID := uuid.New()
_, cancel := context.WithCancel(context.Background())
defer cancel()
manager.RegisterTask(taskID, cancel)
manager.UnregisterTask(taskID)
manager.mu.RLock()
_, exists := manager.cancelFuncs[taskID]
manager.mu.RUnlock()
assert.False(t, exists, "Task should be unregistered")
}
func Test_CancelTask_OnSameInstance_TaskCancelledViaPubSub(t *testing.T) {
manager := taskCancelManager
taskID := uuid.New()
ctx, cancel := context.WithCancel(context.Background())
cancelled := false
var mu sync.Mutex
wrappedCancel := func() {
mu.Lock()
cancelled = true
mu.Unlock()
cancel()
}
manager.RegisterTask(taskID, wrappedCancel)
manager.StartSubscription()
time.Sleep(100 * time.Millisecond)
err := manager.CancelTask(taskID)
assert.NoError(t, err, "Cancel should not return error")
time.Sleep(500 * time.Millisecond)
mu.Lock()
wasCancelled := cancelled
mu.Unlock()
assert.True(t, wasCancelled, "Cancel function should have been called")
assert.Error(t, ctx.Err(), "Context should be cancelled")
}
func Test_CancelTask_FromDifferentInstance_TaskCancelledOnRunningInstance(t *testing.T) {
manager1 := taskCancelManager
manager2 := taskCancelManager
taskID := uuid.New()
ctx, cancel := context.WithCancel(context.Background())
cancelled := false
var mu sync.Mutex
wrappedCancel := func() {
mu.Lock()
cancelled = true
mu.Unlock()
cancel()
}
manager1.RegisterTask(taskID, wrappedCancel)
manager1.StartSubscription()
manager2.StartSubscription()
time.Sleep(100 * time.Millisecond)
err := manager2.CancelTask(taskID)
assert.NoError(t, err, "Cancel should not return error")
time.Sleep(500 * time.Millisecond)
mu.Lock()
wasCancelled := cancelled
mu.Unlock()
assert.True(t, wasCancelled, "Cancel function should have been called on instance 1")
assert.Error(t, ctx.Err(), "Context should be cancelled")
}
func Test_CancelTask_WhenTaskDoesNotExist_NoErrorReturned(t *testing.T) {
manager := taskCancelManager
manager.StartSubscription()
time.Sleep(100 * time.Millisecond)
nonExistentID := uuid.New()
err := manager.CancelTask(nonExistentID)
assert.NoError(t, err, "Cancelling non-existent task should not error")
}
func Test_CancelTask_WithMultipleTasks_AllTasksCancelled(t *testing.T) {
manager := taskCancelManager
numTasks := 5
taskIDs := make([]uuid.UUID, numTasks)
contexts := make([]context.Context, numTasks)
cancels := make([]context.CancelFunc, numTasks)
cancelledFlags := make([]bool, numTasks)
var mu sync.Mutex
for i := 0; i < numTasks; i++ {
taskIDs[i] = uuid.New()
contexts[i], cancels[i] = context.WithCancel(context.Background())
idx := i
wrappedCancel := func() {
mu.Lock()
cancelledFlags[idx] = true
mu.Unlock()
cancels[idx]()
}
manager.RegisterTask(taskIDs[i], wrappedCancel)
}
manager.StartSubscription()
time.Sleep(100 * time.Millisecond)
for i := 0; i < numTasks; i++ {
err := manager.CancelTask(taskIDs[i])
assert.NoError(t, err, "Cancel should not return error")
}
time.Sleep(1 * time.Second)
mu.Lock()
for i := 0; i < numTasks; i++ {
assert.True(t, cancelledFlags[i], "Task %d should be cancelled", i)
assert.Error(t, contexts[i].Err(), "Context %d should be cancelled", i)
}
mu.Unlock()
}
func Test_CancelTask_AfterUnregister_TaskNotCancelled(t *testing.T) {
manager := taskCancelManager
taskID := uuid.New()
_, cancel := context.WithCancel(context.Background())
defer cancel()
cancelled := false
var mu sync.Mutex
wrappedCancel := func() {
mu.Lock()
cancelled = true
mu.Unlock()
cancel()
}
manager.RegisterTask(taskID, wrappedCancel)
manager.StartSubscription()
time.Sleep(100 * time.Millisecond)
manager.UnregisterTask(taskID)
err := manager.CancelTask(taskID)
assert.NoError(t, err, "Cancel should not return error")
time.Sleep(500 * time.Millisecond)
mu.Lock()
wasCancelled := cancelled
mu.Unlock()
assert.False(t, wasCancelled, "Cancel function should not be called after unregister")
}

View File

@@ -0,0 +1,25 @@
package task_cancellation
import (
"context"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
"sync"
"github.com/google/uuid"
)
var taskCancelManager = &TaskCancelManager{
sync.RWMutex{},
make(map[uuid.UUID]context.CancelFunc),
cache_utils.NewPubSubManager(),
logger.GetLogger(),
}
func GetTaskCancelManager() *TaskCancelManager {
return taskCancelManager
}
func SetupDependencies() {
taskCancelManager.StartSubscription()
}

View File

@@ -0,0 +1,18 @@
package task_registry
import (
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
)
var taskNodesRegistry = &TaskNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
}
func GetTaskNodesRegistry() *TaskNodesRegistry {
return taskNodesRegistry
}

View File

@@ -0,0 +1,29 @@
package task_registry
import (
"time"
"github.com/google/uuid"
)
type TaskNode struct {
ID uuid.UUID `json:"id"`
ThroughputMBs int `json:"throughputMBs"`
LastHeartbeat time.Time `json:"lastHeartbeat"`
}
type TaskNodeStats struct {
ID uuid.UUID `json:"id"`
ActiveTasks int `json:"activeTasks"`
}
type TaskSubmitMessage struct {
NodeID string `json:"nodeId"`
TaskID string `json:"taskId"`
IsCallNotifier bool `json:"isCallNotifier"`
}
type TaskCompletionMessage struct {
NodeID string `json:"nodeId"`
TaskID string `json:"taskId"`
}

View File

@@ -0,0 +1,641 @@
package task_registry
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"time"
cache_utils "databasus-backend/internal/util/cache"
"github.com/google/uuid"
"github.com/valkey-io/valkey-go"
)
const (
nodeInfoKeyPrefix = "node:"
nodeInfoKeySuffix = ":info"
nodeActiveTasksPrefix = "node:"
nodeActiveTasksSuffix = ":active_tasks"
taskSubmitChannel = "task:submit"
taskCompletionChannel = "task:completion"
deadNodeThreshold = 2 * time.Minute
cleanupTickerInterval = 1 * time.Second
)
// TaskNodesRegistry helps to sync tasks scheduler (backuping or restoring)
// and task nodes which are used for network-intensive tasks processing
//
// Features:
// - Track node availability and load level
// - Assign from scheduler to node tasks needed to be processed
// - Notify scheduler from node about task completion
//
// Important things to remember:
// - Node can contain different tasks types so when task is assigned
// or node's tasks cleaned - should be performed DB check in DB
// that task with this ID exists for this task type at all
// - Nodes without heathbeat for more than 2 minutes are not included
// in available nodes list and stats
//
// Cleanup dead nodes performed on 2 levels:
// - List and stats functions do not return dead nodes
// - Periodically dead nodes are cleaned up in cache (to not
// accumulate too many dead nodes in cache)
type TaskNodesRegistry struct {
client valkey.Client
logger *slog.Logger
timeout time.Duration
pubsubTasks *cache_utils.PubSubManager
pubsubCompletions *cache_utils.PubSubManager
}
func (r *TaskNodesRegistry) Run(ctx context.Context) {
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
}
ticker := time.NewTicker(cleanupTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes", "error", err)
}
}
}
}
func (r *TaskNodesRegistry) GetAvailableNodes() ([]TaskNode, error) {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
var allKeys []string
cursor := uint64(0)
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
for {
result := r.client.Do(
ctx,
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
)
if result.Error() != nil {
return nil, fmt.Errorf("failed to scan node keys: %w", result.Error())
}
scanResult, err := result.AsScanEntry()
if err != nil {
return nil, fmt.Errorf("failed to parse scan result: %w", err)
}
allKeys = append(allKeys, scanResult.Elements...)
cursor = scanResult.Cursor
if cursor == 0 {
break
}
}
if len(allKeys) == 0 {
return []TaskNode{}, nil
}
keyDataMap, err := r.pipelineGetKeys(allKeys)
if err != nil {
return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
}
threshold := time.Now().UTC().Add(-deadNodeThreshold)
var nodes []TaskNode
for key, data := range keyDataMap {
// Skip if the key doesn't exist (data is empty)
if len(data) == 0 {
continue
}
var node TaskNode
if err := json.Unmarshal(data, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
continue
}
// Skip nodes with zero/uninitialized heartbeat
if node.LastHeartbeat.IsZero() {
continue
}
if node.LastHeartbeat.Before(threshold) {
continue
}
nodes = append(nodes, node)
}
return nodes, nil
}
func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
var allKeys []string
cursor := uint64(0)
pattern := nodeActiveTasksPrefix + "*" + nodeActiveTasksSuffix
for {
result := r.client.Do(
ctx,
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build(),
)
if result.Error() != nil {
return nil, fmt.Errorf("failed to scan active tasks keys: %w", result.Error())
}
scanResult, err := result.AsScanEntry()
if err != nil {
return nil, fmt.Errorf("failed to parse scan result: %w", err)
}
allKeys = append(allKeys, scanResult.Elements...)
cursor = scanResult.Cursor
if cursor == 0 {
break
}
}
if len(allKeys) == 0 {
return []TaskNodeStats{}, nil
}
keyDataMap, err := r.pipelineGetKeys(allKeys)
if err != nil {
return nil, fmt.Errorf("failed to pipeline get active tasks keys: %w", err)
}
var nodeInfoKeys []string
nodeIDToStatsKey := make(map[string]string)
for key := range keyDataMap {
nodeID := r.extractNodeIDFromKey(key, nodeActiveTasksPrefix, nodeActiveTasksSuffix)
nodeIDStr := nodeID.String()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeIDStr, nodeInfoKeySuffix)
nodeInfoKeys = append(nodeInfoKeys, infoKey)
nodeIDToStatsKey[infoKey] = key
}
nodeInfoMap, err := r.pipelineGetKeys(nodeInfoKeys)
if err != nil {
return nil, fmt.Errorf("failed to pipeline get node info keys: %w", err)
}
threshold := time.Now().UTC().Add(-deadNodeThreshold)
var stats []TaskNodeStats
for infoKey, nodeData := range nodeInfoMap {
// Skip if the info key doesn't exist (nodeData is empty)
if len(nodeData) == 0 {
continue
}
var node TaskNode
if err := json.Unmarshal(nodeData, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data", "key", infoKey, "error", err)
continue
}
// Skip nodes with zero/uninitialized heartbeat
if node.LastHeartbeat.IsZero() {
continue
}
if node.LastHeartbeat.Before(threshold) {
continue
}
statsKey := nodeIDToStatsKey[infoKey]
tasksData := keyDataMap[statsKey]
count, err := r.parseIntFromBytes(tasksData)
if err != nil {
r.logger.Warn("Failed to parse active tasks count", "key", statsKey, "error", err)
continue
}
stat := TaskNodeStats{
ID: node.ID,
ActiveTasks: int(count),
}
stats = append(stats, stat)
}
return stats, nil
}
func (r *TaskNodesRegistry) IncrementTasksInProgress(nodeID string) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix)
result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())
if result.Error() != nil {
return fmt.Errorf(
"failed to increment tasks in progress for node %s: %w",
nodeID,
result.Error(),
)
}
return nil
}
func (r *TaskNodesRegistry) DecrementTasksInProgress(nodeID string) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix)
result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
if result.Error() != nil {
return fmt.Errorf(
"failed to decrement tasks in progress for node %s: %w",
nodeID,
result.Error(),
)
}
newValue, err := result.AsInt64()
if err != nil {
return fmt.Errorf("failed to parse decremented value for node %s: %w", nodeID, err)
}
if newValue < 0 {
setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
setCancel()
r.logger.Warn("Active tasks counter went below 0, reset to 0", "nodeID", nodeID)
}
return nil
}
func (r *TaskNodesRegistry) HearthbeatNodeInRegistry(now time.Time, node TaskNode) error {
if now.IsZero() {
return fmt.Errorf("cannot register node with zero heartbeat timestamp")
}
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
node.LastHeartbeat = now
data, err := json.Marshal(node)
if err != nil {
return fmt.Errorf("failed to marshal node: %w", err)
}
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node.ID.String(), nodeInfoKeySuffix)
result := r.client.Do(
ctx,
r.client.B().Set().Key(key).Value(string(data)).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to register node %s: %w", node.ID, result.Error())
}
return nil
}
func (r *TaskNodesRegistry) UnregisterNodeFromRegistry(node TaskNode) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node.ID.String(), nodeInfoKeySuffix)
counterKey := fmt.Sprintf(
"%s%s%s",
nodeActiveTasksPrefix,
node.ID.String(),
nodeActiveTasksSuffix,
)
result := r.client.Do(
ctx,
r.client.B().Del().Key(infoKey, counterKey).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to unregister node %s: %w", node.ID, result.Error())
}
r.logger.Info("Unregistered node from registry", "nodeID", node.ID)
return nil
}
func (r *TaskNodesRegistry) AssignTaskToNode(
targetNodeID string,
taskID uuid.UUID,
isCallNotifier bool,
) error {
ctx := context.Background()
message := TaskSubmitMessage{
NodeID: targetNodeID,
TaskID: taskID.String(),
IsCallNotifier: isCallNotifier,
}
messageJSON, err := json.Marshal(message)
if err != nil {
return fmt.Errorf("failed to marshal task submit message: %w", err)
}
err = r.pubsubTasks.Publish(ctx, taskSubmitChannel, string(messageJSON))
if err != nil {
return fmt.Errorf("failed to publish task submit message: %w", err)
}
return nil
}
func (r *TaskNodesRegistry) SubscribeNodeForTasksAssignment(
nodeID string,
handler func(taskID uuid.UUID, isCallNotifier bool),
) error {
ctx := context.Background()
wrappedHandler := func(message string) {
var msg TaskSubmitMessage
if err := json.Unmarshal([]byte(message), &msg); err != nil {
r.logger.Warn("Failed to unmarshal task submit message", "error", err)
return
}
if msg.NodeID != nodeID {
return
}
taskID, err := uuid.Parse(msg.TaskID)
if err != nil {
r.logger.Warn(
"Failed to parse task ID from message",
"taskId",
msg.TaskID,
"error",
err,
)
return
}
handler(taskID, msg.IsCallNotifier)
}
err := r.pubsubTasks.Subscribe(ctx, taskSubmitChannel, wrappedHandler)
if err != nil {
return fmt.Errorf("failed to subscribe to task submit channel: %w", err)
}
r.logger.Info("Subscribed to task submit channel", "nodeID", nodeID)
return nil
}
func (r *TaskNodesRegistry) UnsubscribeNodeForTasksAssignments() error {
err := r.pubsubTasks.Close()
if err != nil {
return fmt.Errorf("failed to unsubscribe from task submit channel: %w", err)
}
r.logger.Info("Unsubscribed from task submit channel")
return nil
}
func (r *TaskNodesRegistry) PublishTaskCompletion(nodeID string, taskID uuid.UUID) error {
ctx := context.Background()
message := TaskCompletionMessage{
NodeID: nodeID,
TaskID: taskID.String(),
}
messageJSON, err := json.Marshal(message)
if err != nil {
return fmt.Errorf("failed to marshal task completion message: %w", err)
}
err = r.pubsubCompletions.Publish(ctx, taskCompletionChannel, string(messageJSON))
if err != nil {
return fmt.Errorf("failed to publish task completion message: %w", err)
}
return nil
}
func (r *TaskNodesRegistry) SubscribeForTasksCompletions(
handler func(nodeID string, taskID uuid.UUID),
) error {
ctx := context.Background()
wrappedHandler := func(message string) {
var msg TaskCompletionMessage
if err := json.Unmarshal([]byte(message), &msg); err != nil {
r.logger.Warn("Failed to unmarshal task completion message", "error", err)
return
}
taskID, err := uuid.Parse(msg.TaskID)
if err != nil {
r.logger.Warn(
"Failed to parse task ID from completion message",
"taskId",
msg.TaskID,
"error",
err,
)
return
}
handler(msg.NodeID, taskID)
}
err := r.pubsubCompletions.Subscribe(ctx, taskCompletionChannel, wrappedHandler)
if err != nil {
return fmt.Errorf("failed to subscribe to task completion channel: %w", err)
}
r.logger.Info("Subscribed to task completion channel")
return nil
}
func (r *TaskNodesRegistry) UnsubscribeForTasksCompletions() error {
err := r.pubsubCompletions.Close()
if err != nil {
return fmt.Errorf("failed to unsubscribe from task completion channel: %w", err)
}
r.logger.Info("Unsubscribed from task completion channel")
return nil
}
func (r *TaskNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
nodeIDStr := strings.TrimPrefix(key, prefix)
nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix)
nodeID, err := uuid.Parse(nodeIDStr)
if err != nil {
r.logger.Warn("Failed to parse node ID from key", "key", key, "error", err)
return uuid.Nil
}
return nodeID
}
func (r *TaskNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
if len(keys) == 0 {
return make(map[string][]byte), nil
}
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
commands := make([]valkey.Completed, 0, len(keys))
for _, key := range keys {
commands = append(commands, r.client.B().Get().Key(key).Build())
}
results := r.client.DoMulti(ctx, commands...)
keyDataMap := make(map[string][]byte, len(keys))
for i, result := range results {
if result.Error() != nil {
r.logger.Warn("Failed to get key in pipeline", "key", keys[i], "error", result.Error())
continue
}
data, err := result.AsBytes()
if err != nil {
r.logger.Warn("Failed to parse key data in pipeline", "key", keys[i], "error", err)
continue
}
keyDataMap[keys[i]] = data
}
return keyDataMap, nil
}
func (r *TaskNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
str := string(data)
var count int64
_, err := fmt.Sscanf(str, "%d", &count)
if err != nil {
return 0, fmt.Errorf("failed to parse integer from bytes: %w", err)
}
return count, nil
}
func (r *TaskNodesRegistry) cleanupDeadNodes() error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
var allKeys []string
cursor := uint64(0)
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
for {
result := r.client.Do(
ctx,
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to scan node keys: %w", result.Error())
}
scanResult, err := result.AsScanEntry()
if err != nil {
return fmt.Errorf("failed to parse scan result: %w", err)
}
allKeys = append(allKeys, scanResult.Elements...)
cursor = scanResult.Cursor
if cursor == 0 {
break
}
}
if len(allKeys) == 0 {
return nil
}
keyDataMap, err := r.pipelineGetKeys(allKeys)
if err != nil {
return fmt.Errorf("failed to pipeline get node keys: %w", err)
}
threshold := time.Now().UTC().Add(-deadNodeThreshold)
var deadNodeKeys []string
for key, data := range keyDataMap {
// Skip if the key doesn't exist (data is empty)
if len(data) == 0 {
continue
}
var node TaskNode
if err := json.Unmarshal(data, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data during cleanup", "key", key, "error", err)
continue
}
// Skip nodes with zero/uninitialized heartbeat
if node.LastHeartbeat.IsZero() {
continue
}
if node.LastHeartbeat.Before(threshold) {
nodeID := node.ID.String()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeID, nodeInfoKeySuffix)
statsKey := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix)
deadNodeKeys = append(deadNodeKeys, infoKey, statsKey)
r.logger.Info(
"Marking node for cleanup",
"nodeID", nodeID,
"lastHeartbeat", node.LastHeartbeat,
"threshold", threshold,
)
}
}
if len(deadNodeKeys) == 0 {
return nil
}
delCtx, delCancel := context.WithTimeout(context.Background(), r.timeout)
defer delCancel()
result := r.client.Do(
delCtx,
r.client.B().Del().Key(deadNodeKeys...).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to delete dead node keys: %w", result.Error())
}
deletedCount, err := result.AsInt64()
if err != nil {
return fmt.Errorf("failed to parse deleted count: %w", err)
}
r.logger.Info("Cleaned up dead nodes", "deletedKeysCount", deletedCount)
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mariadbtypes "databasus-backend/internal/features/databases/databases/mariadb"
@@ -189,7 +189,7 @@ func testMariadbBackupRestoreForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mariadb"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -286,7 +286,7 @@ func testMariadbBackupRestoreWithEncryptionForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
newDBName := "restoreddb_mariadb_encrypted"
@@ -394,7 +394,7 @@ func testMariadbBackupRestoreWithReadOnlyUserForVersion(
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mariadb_readonly"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))

View File

@@ -19,7 +19,7 @@ import (
"go.mongodb.org/mongo-driver/mongo/options"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mongodbtypes "databasus-backend/internal/features/databases/databases/mongodb"
@@ -161,7 +161,7 @@ func testMongodbBackupRestoreForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mongodb_" + uuid.New().String()[:8]
@@ -239,7 +239,7 @@ func testMongodbBackupRestoreWithEncryptionForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
newDBName := "restoreddb_mongodb_enc_" + uuid.New().String()[:8]
@@ -328,7 +328,7 @@ func testMongodbBackupRestoreWithReadOnlyUserForVersion(
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mongodb_ro_" + uuid.New().String()[:8]

View File

@@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mysqltypes "databasus-backend/internal/features/databases/databases/mysql"
@@ -164,7 +164,7 @@ func testMysqlBackupRestoreForVersion(t *testing.T, mysqlVersion tools.MysqlVers
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mysql"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -261,7 +261,7 @@ func testMysqlBackupRestoreWithEncryptionForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
newDBName := "restoreddb_mysql_encrypted"
@@ -369,7 +369,7 @@ func testMysqlBackupRestoreWithReadOnlyUserForVersion(
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mysql_readonly"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))

View File

@@ -18,6 +18,7 @@ import (
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
@@ -190,7 +191,7 @@ func Test_BackupAndRestoreSupabase_PublicSchemaOnly_RestoreIsSuccessful(t *testi
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
_, err = supabaseDB.Exec(fmt.Sprintf(`DELETE FROM public.%s`, tableName))
assert.NoError(t, err)
@@ -410,7 +411,7 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string, cp
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restoreddb_%s_cpu%d_%s", pgVersion, cpuCount, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -527,7 +528,7 @@ func testSchemaSelectionAllSchemasForVersion(t *testing.T, pgVersion string, por
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restored_all_schemas_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -655,7 +656,7 @@ func testBackupRestoreWithExcludeExtensionsForVersion(t *testing.T, pgVersion st
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restored_exclude_ext_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -789,7 +790,7 @@ func testBackupRestoreWithoutExcludeExtensionsForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restored_with_ext_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -928,7 +929,7 @@ func testBackupRestoreWithReadOnlyUserForVersion(t *testing.T, pgVersion string,
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restoreddb_readonly_%s", uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -1048,7 +1049,7 @@ func testSchemaSelectionOnlySpecifiedSchemasForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restored_specific_schemas_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -1161,7 +1162,7 @@ func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, p
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
newDBName := fmt.Sprintf("restoreddb_encrypted_%s", uuid.New().String()[:8])
@@ -1242,7 +1243,7 @@ func waitForBackupCompletion(
databaseID uuid.UUID,
token string,
timeout time.Duration,
) *backups.Backup {
) *backups_core.Backup {
startTime := time.Now()
pollInterval := 500 * time.Millisecond
@@ -1263,10 +1264,10 @@ func waitForBackupCompletion(
if len(response.Backups) > 0 {
backup := response.Backups[0]
if backup.Status == backups.BackupStatusCompleted {
if backup.Status == backups_core.BackupStatusCompleted {
return backup
}
if backup.Status == backups.BackupStatusFailed {
if backup.Status == backups_core.BackupStatusFailed {
failMsg := "unknown error"
if backup.FailMessage != nil {
failMsg = *backup.FailMessage

View File

@@ -0,0 +1,22 @@
package tests
import (
"os"
"testing"
"databasus-backend/internal/features/backups/backups/backuping"
cache_utils "databasus-backend/internal/util/cache"
)
func TestMain(m *testing.M) {
cache_utils.ClearAllCache()
backuperNode := backuping.CreateTestBackuperNode()
cancel := backuping.StartBackuperNodeForTest(&testing.T{}, backuperNode)
exitCode := m.Run()
backuping.StopBackuperNodeForTest(&testing.T{}, cancel, backuperNode)
os.Exit(exitCode)
}

View File

@@ -2,13 +2,12 @@ package users_controllers
import (
users_services "databasus-backend/internal/features/users/services"
"golang.org/x/time/rate"
cache_utils "databasus-backend/internal/util/cache"
)
var userController = &UserController{
users_services.GetUserService(),
rate.NewLimiter(rate.Limit(3), 3), // 3 rps with 3 burst
cache_utils.NewRateLimiter(cache_utils.GetValkeyClient()),
}
var settingsController = &SettingsController{

View File

@@ -14,7 +14,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"golang.org/x/time/rate"
)
func Test_AdminLifecycleE2E_CompletesSuccessfully(t *testing.T) {
@@ -185,7 +184,6 @@ func createUserTestRouter() *gin.Engine {
// Register protected routes with auth middleware
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
GetUserController().RegisterProtectedRoutes(protected.(*gin.RouterGroup))
GetUserController().SetSignInLimiter(rate.NewLimiter(rate.Limit(100), 100))
// Setup audit log service
users_services.GetUserService().SetAuditLogWriter(&AuditLogWriterStub{})

View File

@@ -3,20 +3,21 @@ package users_controllers
import (
"errors"
"net/http"
"time"
"databasus-backend/internal/config"
user_dto "databasus-backend/internal/features/users/dto"
users_errors "databasus-backend/internal/features/users/errors"
user_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
cache_utils "databasus-backend/internal/util/cache"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
type UserController struct {
userService *users_services.UserService
signinLimiter *rate.Limiter
userService *users_services.UserService
rateLimiter *cache_utils.RateLimiter
}
func (c *UserController) RegisterRoutes(router *gin.RouterGroup) {
@@ -39,10 +40,6 @@ func (c *UserController) RegisterProtectedRoutes(router *gin.RouterGroup) {
router.POST("/users/invite", c.InviteUser)
}
func (c *UserController) SetSignInLimiter(limiter *rate.Limiter) {
c.signinLimiter = limiter
}
// SignUp
// @Summary Register a new user
// @Description Register a new user with email and password
@@ -81,8 +78,14 @@ func (c *UserController) SignUp(ctx *gin.Context) {
// @Failure 429 {object} map[string]string "Rate limit exceeded"
// @Router /users/signin [post]
func (c *UserController) SignIn(ctx *gin.Context) {
// We use rate limiter to prevent brute force attacks
if !c.signinLimiter.Allow() {
var request user_dto.SignInRequestDTO
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
return
}
allowed, _ := c.rateLimiter.CheckLimit(request.Email, "signin", 10, 1*time.Minute)
if !allowed {
ctx.JSON(
http.StatusTooManyRequests,
gin.H{"error": "Rate limit exceeded. Please try again later."},
@@ -90,12 +93,6 @@ func (c *UserController) SignIn(ctx *gin.Context) {
return
}
var request user_dto.SignInRequestDTO
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
return
}
response, err := c.userService.SignIn(&request)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})

View File

@@ -1146,3 +1146,48 @@ func Test_GoogleOAuth_WithInvitedUser_ActivatesUser(t *testing.T) {
assert.Equal(t, email, response.Email)
assert.False(t, response.IsNewUser)
}
func Test_SignIn_WithExcessiveAttempts_RateLimitEnforced(t *testing.T) {
router := createUserTestRouter()
email := "ratelimit" + uuid.New().String() + "@example.com"
password := "testpassword123"
// Create a user first
signupRequest := users_dto.SignUpRequestDTO{
Email: email,
Password: password,
Name: "Rate Limit Test User",
}
test_utils.MakePostRequest(t, router, "/api/v1/users/signup", "", signupRequest, http.StatusOK)
// Make 10 sign-in attempts (should succeed)
for range 10 {
signinRequest := users_dto.SignInRequestDTO{
Email: email,
Password: password,
}
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/signin",
"",
signinRequest,
http.StatusOK,
)
}
// 11th attempt should be rate limited
signinRequest := users_dto.SignInRequestDTO{
Email: email,
Password: password,
}
resp := test_utils.MakePostRequest(
t,
router,
"/api/v1/users/signin",
"",
signinRequest,
http.StatusTooManyRequests,
)
assert.Contains(t, string(resp.Body), "Rate limit exceeded")
}

125
backend/internal/util/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,125 @@
package cache_utils
import (
"context"
"crypto/tls"
"databasus-backend/internal/config"
"sync"
"github.com/valkey-io/valkey-go"
)
var (
once sync.Once
valkeyClient valkey.Client
)
func getCache() valkey.Client {
once.Do(func() {
env := config.GetEnv()
options := valkey.ClientOption{
InitAddress: []string{env.ValkeyHost + ":" + env.ValkeyPort},
Password: env.ValkeyPassword,
Username: env.ValkeyUsername,
}
if env.ValkeyIsSsl {
options.TLSConfig = &tls.Config{
ServerName: env.ValkeyHost,
}
}
client, err := valkey.NewClient(options)
if err != nil {
panic(err)
}
valkeyClient = client
})
return valkeyClient
}
func GetValkeyClient() valkey.Client {
return getCache()
}
func TestCacheConnection() {
// Get Valkey client from cache package
client := getCache()
// Create a simple test cache util for strings
cacheUtil := NewCacheUtil[string](client, "test:")
// Test data
testKey := "connection_test"
testValue := "valkey_is_working"
// Test Set operation
cacheUtil.Set(testKey, &testValue)
// Test Get operation
retrievedValue := cacheUtil.Get(testKey)
// Verify the value was retrieved correctly
if retrievedValue == nil {
panic("Cache test failed: could not retrieve cached value")
}
if *retrievedValue != testValue {
panic("Cache test failed: retrieved value does not match expected")
}
// Clean up - remove test key
cacheUtil.Invalidate(testKey)
// Verify cleanup worked
cleanupCheck := cacheUtil.Get(testKey)
if cleanupCheck != nil {
panic("Cache test failed: test key was not properly invalidated")
}
}
func ClearAllCache() error {
pattern := "*"
cursor := uint64(0)
batchSize := int64(100)
cacheUtil := NewCacheUtil[string](getCache(), "")
for {
ctx, cancel := context.WithTimeout(context.Background(), DefaultQueueTimeout)
result := cacheUtil.client.Do(
ctx,
cacheUtil.client.B().Scan().Cursor(cursor).Match(pattern).Count(batchSize).Build(),
)
cancel()
if result.Error() != nil {
return result.Error()
}
scanResult, err := result.AsScanEntry()
if err != nil {
return err
}
if len(scanResult.Elements) > 0 {
delCtx, delCancel := context.WithTimeout(context.Background(), cacheUtil.timeout)
cacheUtil.client.Do(
delCtx,
cacheUtil.client.B().Del().Key(scanResult.Elements...).Build(),
)
delCancel()
}
cursor = scanResult.Cursor
if cursor == 0 {
break
}
}
return nil
}

View File

@@ -0,0 +1,51 @@
package cache_utils
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_ClearAllCache_AfterClear_CacheIsEmpty(t *testing.T) {
client := getCache()
// Arrange: Set up multiple cache entries with different prefixes
testKeys := []struct {
prefix string
key string
value string
}{
{"test:user:", "user1", "John Doe"},
{"test:user:", "user2", "Jane Smith"},
{"test:session:", "session1", "abc123"},
{"test:session:", "session2", "def456"},
{"test:data:", "item1", "value1"},
}
// Set all test keys
for _, tk := range testKeys {
cacheUtil := NewCacheUtil[string](client, tk.prefix)
cacheUtil.Set(tk.key, &tk.value)
}
// Verify keys were set correctly before clearing
for _, tk := range testKeys {
cacheUtil := NewCacheUtil[string](client, tk.prefix)
retrieved := cacheUtil.Get(tk.key)
assert.NotNil(t, retrieved, "Key %s should exist before clearing", tk.prefix+tk.key)
assert.Equal(t, tk.value, *retrieved, "Retrieved value should match set value")
}
// Act: Clear all cache
err := ClearAllCache()
// Assert: No error returned
assert.NoError(t, err, "ClearAllCache should not return an error")
// Assert: All keys should be deleted
for _, tk := range testKeys {
cacheUtil := NewCacheUtil[string](client, tk.prefix)
retrieved := cacheUtil.Get(tk.key)
assert.Nil(t, retrieved, "Key %s should be deleted after clearing", tk.prefix+tk.key)
}
}

109
backend/internal/util/cache/pubsub.go vendored Normal file
View File

@@ -0,0 +1,109 @@
package cache_utils
import (
"context"
"fmt"
"log/slog"
"sync"
"databasus-backend/internal/util/logger"
"github.com/valkey-io/valkey-go"
)
type PubSubManager struct {
client valkey.Client
subscriptions map[string]context.CancelFunc
mu sync.RWMutex
logger *slog.Logger
}
func NewPubSubManager() *PubSubManager {
return &PubSubManager{
client: getCache(),
subscriptions: make(map[string]context.CancelFunc),
logger: logger.GetLogger(),
}
}
func (m *PubSubManager) Subscribe(
ctx context.Context,
channel string,
handler func(message string),
) error {
m.mu.Lock()
if _, exists := m.subscriptions[channel]; exists {
m.mu.Unlock()
return fmt.Errorf("already subscribed to channel: %s", channel)
}
subCtx, cancel := context.WithCancel(ctx)
m.subscriptions[channel] = cancel
m.mu.Unlock()
go m.subscriptionLoop(subCtx, channel, handler)
return nil
}
func (m *PubSubManager) Publish(ctx context.Context, channel string, message string) error {
cmd := m.client.B().Publish().Channel(channel).Message(message).Build()
result := m.client.Do(ctx, cmd)
if err := result.Error(); err != nil {
m.logger.Error("Failed to publish message to Redis", "channel", channel, "error", err)
return fmt.Errorf("failed to publish message: %w", err)
}
return nil
}
func (m *PubSubManager) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
for channel, cancel := range m.subscriptions {
cancel()
delete(m.subscriptions, channel)
}
return nil
}
func (m *PubSubManager) subscriptionLoop(
ctx context.Context,
channel string,
handler func(message string),
) {
defer func() {
if r := recover(); r != nil {
m.logger.Error("Panic in subscription handler", "channel", channel, "panic", r)
}
}()
m.logger.Info("Starting subscription", "channel", channel)
err := m.client.Receive(
ctx,
m.client.B().Subscribe().Channel(channel).Build(),
func(msg valkey.PubSubMessage) {
defer func() {
if r := recover(); r != nil {
m.logger.Error("Panic in message handler", "channel", channel, "panic", r)
}
}()
handler(msg.Message)
},
)
if err != nil && ctx.Err() == nil {
m.logger.Error("Subscription error", "channel", channel, "error", err)
} else if ctx.Err() != nil {
m.logger.Info("Subscription cancelled", "channel", channel)
}
m.mu.Lock()
delete(m.subscriptions, channel)
m.mu.Unlock()
}

View File

@@ -0,0 +1,85 @@
package cache_utils
import (
"context"
"fmt"
"time"
"github.com/google/uuid"
"github.com/valkey-io/valkey-go"
)
type RateLimiter struct {
client valkey.Client
}
func NewRateLimiter(client valkey.Client) *RateLimiter {
return &RateLimiter{
client: client,
}
}
func (r *RateLimiter) CheckLimit(
identifier string,
endpoint string,
maxRequests int,
windowDuration time.Duration,
) (bool, error) {
requestID := uuid.New().String()
keyPrefix := fmt.Sprintf("ratelimit:%s:%s", endpoint, identifier)
fullKey := fmt.Sprintf("%s:%s", keyPrefix, requestID)
ctx, cancel := context.WithTimeout(context.Background(), DefaultCacheTimeout)
defer cancel()
// Set the key with TTL
setCmd := r.client.B().
Set().
Key(fullKey).
Value("1").
ExSeconds(int64(windowDuration.Seconds())).
Build()
if err := r.client.Do(ctx, setCmd).Error(); err != nil {
return true, fmt.Errorf("failed to set rate limit key: %w", err)
}
// Count keys matching the pattern
count, err := r.countKeys(keyPrefix)
if err != nil {
return true, fmt.Errorf("failed to count rate limit keys: %w", err)
}
return count <= maxRequests, nil
}
func (r *RateLimiter) countKeys(keyPrefix string) (int, error) {
pattern := keyPrefix + ":*"
cursor := uint64(0)
totalCount := 0
for {
ctx, cancel := context.WithTimeout(context.Background(), DefaultCacheTimeout)
scanCmd := r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build()
result := r.client.Do(ctx, scanCmd)
cancel()
if result.Error() != nil {
return 0, result.Error()
}
scanResult, err := result.AsScanEntry()
if err != nil {
return 0, err
}
totalCount += len(scanResult.Elements)
cursor = scanResult.Cursor
if cursor == 0 {
break
}
}
return totalCount, nil
}

76
backend/internal/util/cache/utils.go vendored Normal file
View File

@@ -0,0 +1,76 @@
package cache_utils
import (
"context"
"encoding/json"
"time"
"github.com/valkey-io/valkey-go"
)
const (
DefaultCacheTimeout = 10 * time.Second
DefaultCacheExpiry = 10 * time.Minute
DefaultQueueTimeout = 30 * time.Second
)
type CacheUtil[T any] struct {
client valkey.Client
prefix string
timeout time.Duration
expiry time.Duration
}
func NewCacheUtil[T any](client valkey.Client, prefix string) *CacheUtil[T] {
return &CacheUtil[T]{
client: client,
prefix: prefix,
timeout: DefaultCacheTimeout,
expiry: DefaultCacheExpiry,
}
}
func (c *CacheUtil[T]) Get(key string) *T {
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
defer cancel()
fullKey := c.prefix + key
result := c.client.Do(ctx, c.client.B().Get().Key(fullKey).Build())
if result.Error() != nil {
return nil
}
data, err := result.AsBytes()
if err != nil {
return nil
}
var item T
if err := json.Unmarshal(data, &item); err != nil {
return nil
}
return &item
}
func (c *CacheUtil[T]) Set(key string, item *T) {
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
defer cancel()
data, err := json.Marshal(item)
if err != nil {
return
}
fullKey := c.prefix + key
c.client.Do(ctx, c.client.B().Set().Key(fullKey).Value(string(data)).Ex(c.expiry).Build())
}
func (c *CacheUtil[T]) Invalidate(key string) {
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
defer cancel()
fullKey := c.prefix + key
c.client.Do(ctx, c.client.B().Del().Key(fullKey).Build())
}

View File

@@ -87,17 +87,20 @@ func VerifyMariadbInstallation(
logger *slog.Logger,
envMode env_utils.EnvMode,
mariadbInstallDir string,
isShowLogs bool,
) {
clientVersions := []MariadbClientVersion{MariadbClientLegacy, MariadbClientModern}
for _, clientVersion := range clientVersions {
binDir := getMariadbBasePath(clientVersion, envMode, mariadbInstallDir)
logger.Info(
"Verifying MariaDB installation",
"clientVersion", clientVersion,
"path", binDir,
)
if isShowLogs {
logger.Info(
"Verifying MariaDB installation",
"clientVersion", clientVersion,
"path", binDir,
)
}
if _, err := os.Stat(binDir); os.IsNotExist(err) {
if envMode == env_utils.EnvModeDevelopment {
@@ -133,12 +136,14 @@ func VerifyMariadbInstallation(
}
cmdPath := GetMariadbExecutable(cmd, dummyServerVersion, envMode, mariadbInstallDir)
logger.Info(
"Checking for MariaDB command",
"clientVersion", clientVersion,
"command", cmd,
"path", cmdPath,
)
if isShowLogs {
logger.Info(
"Checking for MariaDB command",
"clientVersion", clientVersion,
"command", cmd,
"path", cmdPath,
)
}
if _, err := os.Stat(cmdPath); os.IsNotExist(err) {
if envMode == env_utils.EnvModeDevelopment {
@@ -162,11 +167,15 @@ func VerifyMariadbInstallation(
continue
}
logger.Info("MariaDB command found", "clientVersion", clientVersion, "command", cmd)
if isShowLogs {
logger.Info("MariaDB command found", "clientVersion", clientVersion, "command", cmd)
}
}
}
logger.Info("MariaDB client tools verification completed!")
if isShowLogs {
logger.Info("MariaDB client tools verification completed!")
}
}
// IsMariadbBackupVersionHigherThanRestoreVersion checks if backup was made with

Some files were not shown because too many files have changed in this diff Show More