Mirror of https://github.com/databasus/databasus.git (synced 2026-04-06 00:32:03 +02:00)

## Compare commits (19 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 27bec15a29 | |
| | d98baa0656 | |
| | 4344f5ea5e | |
| | 7c6afa5b88 | |
| | dbac799e1b | |
| | 7ee3817089 | |
| | bae6f7f007 | |
| | 55dc087ddd | |
| | c94d0db637 | |
| | a1adef2261 | |
| | 4602dc3f88 | |
| | cbbfc5ea8f | |
| | dd1072e230 | |
| | a495e5317a | |
| | 7eed647038 | |
| | 6973241e25 | |
| | ab181f5b81 | |
| | b60a0cc170 | |
| | f319a497b3 | |

### .github/workflows/ci-release.yml (vendored)

```
@@ -19,16 +19,6 @@ jobs:
        with:
          go-version: "1.24.9"

      - name: Cache Go modules
        uses: actions/cache@v4
        with:
          path: |
            ~/go/pkg/mod
            ~/.cache/go-build
          key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
          restore-keys: |
            ${{ runner.os }}-go-

      - name: Install golangci-lint
        run: |
          curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.7.2

@@ -63,8 +53,6 @@ jobs:
        uses: actions/setup-node@v4
        with:
          node-version: "20"
          cache: "npm"
          cache-dependency-path: frontend/package-lock.json

      - name: Install dependencies
        run: |

@@ -93,8 +81,6 @@ jobs:
        uses: actions/setup-node@v4
        with:
          node-version: "20"
          cache: "npm"
          cache-dependency-path: frontend/package-lock.json

      - name: Install dependencies
        run: |

@@ -136,16 +122,6 @@ jobs:
        with:
          go-version: "1.24.9"

      - name: Cache Go modules
        uses: actions/cache@v4
        with:
          path: |
            ~/go/pkg/mod
            ~/.cache/go-build
          key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
          restore-keys: |
            ${{ runner.os }}-go-

      - name: Create .env file for testing
        run: |
          cd backend

@@ -321,34 +297,6 @@ jobs:
          mkdir -p databasus-data/backups
          mkdir -p databasus-data/temp

      - name: Cache PostgreSQL client tools
        id: cache-postgres
        uses: actions/cache@v4
        with:
          path: /usr/lib/postgresql
          key: postgres-clients-12-18-v1

      - name: Cache MySQL client tools
        id: cache-mysql
        uses: actions/cache@v4
        with:
          path: backend/tools/mysql
          key: mysql-clients-57-80-84-9-v1

      - name: Cache MariaDB client tools
        id: cache-mariadb
        uses: actions/cache@v4
        with:
          path: backend/tools/mariadb
          key: mariadb-clients-106-121-v1

      - name: Cache MongoDB Database Tools
        id: cache-mongodb
        uses: actions/cache@v4
        with:
          path: backend/tools/mongodb
          key: mongodb-database-tools-100.10.0-v1

      - name: Install MySQL dependencies
        run: |
          sudo apt-get update -qq

@@ -356,31 +304,58 @@ jobs:
          sudo ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5
          sudo ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5

      - name: Install PostgreSQL, MySQL, MariaDB and MongoDB client tools
        if: steps.cache-postgres.outputs.cache-hit != 'true' || steps.cache-mysql.outputs.cache-hit != 'true' || steps.cache-mariadb.outputs.cache-hit != 'true' || steps.cache-mongodb.outputs.cache-hit != 'true'
        run: |
          chmod +x backend/tools/download_linux.sh
          cd backend/tools
          ./download_linux.sh

      - name: Setup PostgreSQL symlinks (when using cache)
        if: steps.cache-postgres.outputs.cache-hit == 'true'
      - name: Setup PostgreSQL, MySQL and MariaDB client tools from pre-built assets
        run: |
          cd backend/tools
          mkdir -p postgresql

          # Create directory structure
          mkdir -p postgresql mysql mariadb mongodb/bin

          # Copy PostgreSQL client tools (12-18) from pre-built assets
          for version in 12 13 14 15 16 17 18; do
            version_dir="postgresql/postgresql-$version"
            mkdir -p "$version_dir/bin"
            pg_bin_dir="/usr/lib/postgresql/$version/bin"
            if [ -d "$pg_bin_dir" ]; then
              ln -sf "$pg_bin_dir/pg_dump" "$version_dir/bin/pg_dump"
              ln -sf "$pg_bin_dir/pg_dumpall" "$version_dir/bin/pg_dumpall"
              ln -sf "$pg_bin_dir/psql" "$version_dir/bin/psql"
              ln -sf "$pg_bin_dir/pg_restore" "$version_dir/bin/pg_restore"
              ln -sf "$pg_bin_dir/createdb" "$version_dir/bin/createdb"
              ln -sf "$pg_bin_dir/dropdb" "$version_dir/bin/dropdb"
            fi
            mkdir -p postgresql/postgresql-$version
            cp -r ../../assets/tools/x64/postgresql/postgresql-$version/bin postgresql/postgresql-$version/
          done

          # Copy MySQL client tools (5.7, 8.0, 8.4, 9) from pre-built assets
          for version in 5.7 8.0 8.4 9; do
            mkdir -p mysql/mysql-$version
            cp -r ../../assets/tools/x64/mysql/mysql-$version/bin mysql/mysql-$version/
          done

          # Copy MariaDB client tools (10.6, 12.1) from pre-built assets
          for version in 10.6 12.1; do
            mkdir -p mariadb/mariadb-$version
            cp -r ../../assets/tools/x64/mariadb/mariadb-$version/bin mariadb/mariadb-$version/
          done

          # Make all binaries executable
          chmod +x postgresql/*/bin/*
          chmod +x mysql/*/bin/*
          chmod +x mariadb/*/bin/*

          echo "Pre-built client tools setup complete"

      - name: Install MongoDB Database Tools
        run: |
          cd backend/tools

          # MongoDB Database Tools must be downloaded (not in pre-built assets)
          # They are backward compatible - a single version supports all servers (4.0-8.0)
          MONGODB_TOOLS_URL="https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-x86_64-100.10.0.deb"

          echo "Downloading MongoDB Database Tools..."
          wget -q "$MONGODB_TOOLS_URL" -O /tmp/mongodb-database-tools.deb

          echo "Installing MongoDB Database Tools..."
          sudo dpkg -i /tmp/mongodb-database-tools.deb || sudo apt-get install -f -y --no-install-recommends

          # Create symlinks to tools directory
          ln -sf /usr/bin/mongodump mongodb/bin/mongodump
          ln -sf /usr/bin/mongorestore mongodb/bin/mongorestore

          rm -f /tmp/mongodb-database-tools.deb
          echo "MongoDB Database Tools installed successfully"

      - name: Verify MariaDB client tools exist
        run: |

@@ -726,4 +701,4 @@ jobs:
      - name: Push Helm chart to GHCR
        run: |
          VERSION="${{ needs.determine-version.outputs.new_version }}"
          helm push databasus-${VERSION}.tgz oci://ghcr.io/databasus/charts
```

```
@@ -27,3 +27,10 @@ repos:
        language: system
        files: ^backend/.*\.go$
        pass_filenames: false

      - id: backend-go-mod-tidy
        name: Backend Go Mod Tidy
        entry: bash -c "cd backend && go mod tidy"
        language: system
        files: ^backend/.*\.go$
        pass_filenames: false
```

@@ -1,152 +0,0 @@
---
description:
globs:
alwaysApply: true
---

Always place private methods at the bottom of the file.

**This rule applies to ALL Go files including tests, services, controllers, repositories, etc.**

In Go, exported (public) functions/methods start with uppercase letters, while unexported (private) ones start with lowercase letters.

## Structure Order:

1. Type definitions and constants
2. Public methods/functions (uppercase)
3. Private methods/functions (lowercase)

## Examples:

### Service with methods:

```go
type UserService struct {
    repository *UserRepository
}

// Public methods first
func (s *UserService) CreateUser(user *User) error {
    if err := s.validateUser(user); err != nil {
        return err
    }
    return s.repository.Save(user)
}

func (s *UserService) GetUser(id uuid.UUID) (*User, error) {
    return s.repository.FindByID(id)
}

// Private methods at the bottom
func (s *UserService) validateUser(user *User) error {
    if user.Name == "" {
        return errors.New("name is required")
    }
    return nil
}
```

### Package-level functions:

```go
package utils

// Public functions first
func ProcessData(data []byte) (Result, error) {
    cleaned := sanitizeInput(data)
    return parseData(cleaned)
}

func ValidateInput(input string) bool {
    return isValidFormat(input) && checkLength(input)
}

// Private functions at the bottom
func sanitizeInput(data []byte) []byte {
    // implementation
}

func parseData(data []byte) (Result, error) {
    // implementation
}

func isValidFormat(input string) bool {
    // implementation
}

func checkLength(input string) bool {
    // implementation
}
```

### Test files:

```go
package user_test

// Public test functions first
func Test_CreateUser_ValidInput_UserCreated(t *testing.T) {
    user := createTestUser()
    result, err := service.CreateUser(user)

    assert.NoError(t, err)
    assert.NotNil(t, result)
}

func Test_GetUser_ExistingUser_ReturnsUser(t *testing.T) {
    user := createTestUser()
    // test implementation
}

// Private helper functions at the bottom
func createTestUser() *User {
    return &User{
        Name:  "Test User",
        Email: "test@example.com",
    }
}

func setupTestDatabase() *Database {
    // setup implementation
}
```

### Controller example:

```go
type ProjectController struct {
    service *ProjectService
}

// Public HTTP handlers first
func (c *ProjectController) CreateProject(ctx *gin.Context) {
    var request CreateProjectRequest
    if err := ctx.ShouldBindJSON(&request); err != nil {
        c.handleError(ctx, err)
        return
    }
    // handler logic
}

func (c *ProjectController) GetProject(ctx *gin.Context) {
    projectID := c.extractProjectID(ctx)
    // handler logic
}

// Private helper methods at the bottom
func (c *ProjectController) handleError(ctx *gin.Context, err error) {
    ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}

func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
    return uuid.MustParse(ctx.Param("projectId"))
}
```

## Key Points:

- **Exported/Public** = starts with an uppercase letter (CreateUser, GetProject)
- **Unexported/Private** = starts with a lowercase letter (validateUser, handleError)
- This improves code readability by showing the public API first
- Private helpers are implementation details, so they go at the bottom
- Apply this rule consistently across ALL Go files in the project

@@ -1,45 +0,0 @@
---
description:
globs:
alwaysApply: true
---

## Comment Guidelines

1. **No obvious comments** - Don't state what the code already clearly shows
2. **Meaningful function and variable names** - Code should be self-documenting
3. **Comments for unclear code only** - Only add comments when the code's logic isn't immediately clear

## Key Principles:

- **Code should tell a story** - Use descriptive variable and function names
- **Comments explain WHY, not WHAT** - The code shows what happens; comments explain business logic or complex decisions
- **Prefer refactoring over commenting** - If code needs explaining, consider making it clearer instead
- **API documentation is required** - Swagger comments for all HTTP endpoints are mandatory
- **Complex algorithms deserve comments** - Mathematical formulas, business rules, or non-obvious optimizations

Examples of useless comments:

1.

```sql
-- Create projects table
CREATE TABLE projects (
    id         UUID        PRIMARY KEY DEFAULT gen_random_uuid(),
    name       TEXT        NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
```

2.

```go
// Create test project
project := CreateTestProject(projectName, user, router)
```

3.

```go
// CreateValidLogItems creates valid log items for testing
func CreateValidLogItems(count int, uniqueID string) []logs_receiving.LogItemRequestDTO {
```

@@ -1,133 +0,0 @@
---
description:
globs:
alwaysApply: true
---

1. When we write a controller:

- we combine all routes into a single controller
- we name handler methods after what they do (the .WhatWeDo concept), not as generic "handlers"

2. We use gin and *gin.Context for all routes.
Example:

func (c *TasksController) GetAvailableTasks(ctx *gin.Context) ...

3. We document all routes with Swagger in the following format:

```go
package audit_logs

import (
    "net/http"

    user_models "databasus-backend/internal/features/users/models"

    "github.com/gin-gonic/gin"
    "github.com/google/uuid"
)

type AuditLogController struct {
    auditLogService *AuditLogService
}

func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
    // All audit log endpoints require authentication (handled in main.go)
    auditRoutes := router.Group("/audit-logs")

    auditRoutes.GET("/global", c.GetGlobalAuditLogs)
    auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
}

// GetGlobalAuditLogs
// @Summary Get global audit logs (ADMIN only)
// @Description Retrieve all audit logs across the system
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/global [get]
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
    user, isOk := ctx.MustGet("user").(*user_models.User)
    if !isOk {
        ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
        return
    }

    request := &GetAuditLogsRequest{}
    if err := ctx.ShouldBindQuery(request); err != nil {
        ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
        return
    }

    response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
    if err != nil {
        if err.Error() == "only administrators can view global audit logs" {
            ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
            return
        }
        ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
        return
    }

    ctx.JSON(http.StatusOK, response)
}

// GetUserAuditLogs
// @Summary Get user audit logs
// @Description Retrieve audit logs for a specific user
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param userId path string true "User ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/users/{userId} [get]
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
    user, isOk := ctx.MustGet("user").(*user_models.User)
    if !isOk {
        ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
        return
    }

    userIDStr := ctx.Param("userId")
    targetUserID, err := uuid.Parse(userIDStr)
    if err != nil {
        ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
        return
    }

    request := &GetAuditLogsRequest{}
    if err := ctx.ShouldBindQuery(request); err != nil {
        ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
        return
    }

    response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
    if err != nil {
        if err.Error() == "insufficient permissions to view user audit logs" {
            ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
            return
        }
        ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
        return
    }

    ctx.JSON(http.StatusOK, response)
}
```

@@ -1,671 +0,0 @@
---
alwaysApply: false
---

This is an example of CRUD:

------ backend/internal/features/audit_logs/controller.go ------

```
package audit_logs

import (
    "net/http"

    user_models "databasus-backend/internal/features/users/models"

    "github.com/gin-gonic/gin"
    "github.com/google/uuid"
)

type AuditLogController struct {
    auditLogService *AuditLogService
}

func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
    // All audit log endpoints require authentication (handled in main.go)
    auditRoutes := router.Group("/audit-logs")

    auditRoutes.GET("/global", c.GetGlobalAuditLogs)
    auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
}

// GetGlobalAuditLogs
// @Summary Get global audit logs (ADMIN only)
// @Description Retrieve all audit logs across the system
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/global [get]
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
    user, isOk := ctx.MustGet("user").(*user_models.User)
    if !isOk {
        ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
        return
    }

    request := &GetAuditLogsRequest{}
    if err := ctx.ShouldBindQuery(request); err != nil {
        ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
        return
    }

    response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
    if err != nil {
        if err.Error() == "only administrators can view global audit logs" {
            ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
            return
        }
        ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
        return
    }

    ctx.JSON(http.StatusOK, response)
}

// GetUserAuditLogs
// @Summary Get user audit logs
// @Description Retrieve audit logs for a specific user
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param userId path string true "User ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/users/{userId} [get]
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
    user, isOk := ctx.MustGet("user").(*user_models.User)
    if !isOk {
        ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
        return
    }

    userIDStr := ctx.Param("userId")
    targetUserID, err := uuid.Parse(userIDStr)
    if err != nil {
        ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
        return
    }

    request := &GetAuditLogsRequest{}
    if err := ctx.ShouldBindQuery(request); err != nil {
        ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
        return
    }

    response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
    if err != nil {
        if err.Error() == "insufficient permissions to view user audit logs" {
            ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
            return
        }
        ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
        return
    }

    ctx.JSON(http.StatusOK, response)
}
```

------ backend/internal/features/audit_logs/controller_test.go ------

```
package audit_logs

import (
    "fmt"
    "net/http"
    "testing"
    "time"

    user_enums "databasus-backend/internal/features/users/enums"
    users_middleware "databasus-backend/internal/features/users/middleware"
    users_services "databasus-backend/internal/features/users/services"
    users_testing "databasus-backend/internal/features/users/testing"
    "databasus-backend/internal/storage"
    test_utils "databasus-backend/internal/util/testing"

    "github.com/gin-gonic/gin"
    "github.com/google/uuid"
    "github.com/stretchr/testify/assert"
)

func Test_GetGlobalAuditLogs_AdminSucceedsAndMemberGetsForbidden(t *testing.T) {
    adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
    memberUser := users_testing.CreateTestUser(user_enums.UserRoleMember)
    router := createRouter()
    service := GetAuditLogService()
    projectID := uuid.New()

    // Create test logs
    createAuditLog(service, "Test log with user", &adminUser.UserID, nil)
    createAuditLog(service, "Test log with project", nil, &projectID)
    createAuditLog(service, "Test log standalone", nil, nil)

    // Test ADMIN can access global logs
    var response GetAuditLogsResponse
    test_utils.MakeGetRequestAndUnmarshal(t, router,
        "/api/v1/audit-logs/global?limit=10", "Bearer "+adminUser.Token, http.StatusOK, &response)

    assert.GreaterOrEqual(t, len(response.AuditLogs), 3)
    assert.GreaterOrEqual(t, response.Total, int64(3))

    messages := extractMessages(response.AuditLogs)
    assert.Contains(t, messages, "Test log with user")
    assert.Contains(t, messages, "Test log with project")
    assert.Contains(t, messages, "Test log standalone")

    // Test MEMBER cannot access global logs
    resp := test_utils.MakeGetRequest(t, router, "/api/v1/audit-logs/global",
        "Bearer "+memberUser.Token, http.StatusForbidden)
    assert.Contains(t, string(resp.Body), "only administrators can view global audit logs")
}

func Test_GetUserAuditLogs_PermissionsEnforcedCorrectly(t *testing.T) {
    adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
    user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
    user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
    router := createRouter()
    service := GetAuditLogService()
    projectID := uuid.New()

    // Create test logs for different users
    createAuditLog(service, "Test log user1 first", &user1.UserID, nil)
    createAuditLog(service, "Test log user1 second", &user1.UserID, &projectID)
    createAuditLog(service, "Test log user2 first", &user2.UserID, nil)
    createAuditLog(service, "Test log user2 second", &user2.UserID, &projectID)
    createAuditLog(service, "Test project log", nil, &projectID)

    // Test ADMIN can view any user's logs
    var user1Response GetAuditLogsResponse
    test_utils.MakeGetRequestAndUnmarshal(t, router,
        fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=10", user1.UserID.String()),
        "Bearer "+adminUser.Token, http.StatusOK, &user1Response)

    assert.Equal(t, 2, len(user1Response.AuditLogs))
    messages := extractMessages(user1Response.AuditLogs)
    assert.Contains(t, messages, "Test log user1 first")
    assert.Contains(t, messages, "Test log user1 second")

    // Test user can view own logs
    var ownLogsResponse GetAuditLogsResponse
    test_utils.MakeGetRequestAndUnmarshal(t, router,
        fmt.Sprintf("/api/v1/audit-logs/users/%s", user2.UserID.String()),
        "Bearer "+user2.Token, http.StatusOK, &ownLogsResponse)
    assert.Equal(t, 2, len(ownLogsResponse.AuditLogs))

    // Test user cannot view other user's logs
    resp := test_utils.MakeGetRequest(t, router,
        fmt.Sprintf("/api/v1/audit-logs/users/%s", user1.UserID.String()),
        "Bearer "+user2.Token, http.StatusForbidden)

    assert.Contains(t, string(resp.Body), "insufficient permissions")
}

func Test_FilterAuditLogsByTime_ReturnsOnlyLogsBeforeDate(t *testing.T) {
    adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
    router := createRouter()
    service := GetAuditLogService()
    db := storage.GetDb()
    baseTime := time.Now().UTC()

    // Create logs with different timestamps
    createTimedLog(db, &adminUser.UserID, "Test old log", baseTime.Add(-2*time.Hour))
    createTimedLog(db, &adminUser.UserID, "Test recent log", baseTime.Add(-30*time.Minute))
    createAuditLog(service, "Test current log", &adminUser.UserID, nil)

    // Test filtering - get logs before 1 hour ago
    beforeTime := baseTime.Add(-1 * time.Hour)
    var filteredResponse GetAuditLogsResponse
    test_utils.MakeGetRequestAndUnmarshal(t, router,
        fmt.Sprintf("/api/v1/audit-logs/global?beforeDate=%s", beforeTime.Format(time.RFC3339)),
        "Bearer "+adminUser.Token, http.StatusOK, &filteredResponse)

    // Verify only the old log is returned
    messages := extractMessages(filteredResponse.AuditLogs)
    assert.Contains(t, messages, "Test old log")
    assert.NotContains(t, messages, "Test recent log")
    assert.NotContains(t, messages, "Test current log")

    // Test without filter - should get all logs
    var allResponse GetAuditLogsResponse
    test_utils.MakeGetRequestAndUnmarshal(t, router, "/api/v1/audit-logs/global",
        "Bearer "+adminUser.Token, http.StatusOK, &allResponse)
    assert.GreaterOrEqual(t, len(allResponse.AuditLogs), 3)
}

func createRouter() *gin.Engine {
    gin.SetMode(gin.TestMode)
    router := gin.New()
    SetupDependencies()

    v1 := router.Group("/api/v1")
    protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
    GetAuditLogController().RegisterRoutes(protected.(*gin.RouterGroup))

    return router
}
```

------ backend/internal/features/audit_logs/di.go ------

```
package audit_logs

import (
    users_services "databasus-backend/internal/features/users/services"
    "databasus-backend/internal/util/logger"
)

var auditLogRepository = &AuditLogRepository{}
var auditLogService = &AuditLogService{
    auditLogRepository: auditLogRepository,
    logger:             logger.GetLogger(),
}
var auditLogController = &AuditLogController{
    auditLogService: auditLogService,
}

func GetAuditLogService() *AuditLogService {
    return auditLogService
}

func GetAuditLogController() *AuditLogController {
    return auditLogController
}

func SetupDependencies() {
    users_services.GetUserService().SetAuditLogWriter(auditLogService)
    users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
    users_services.GetManagementService().SetAuditLogWriter(auditLogService)
}
```

------ backend/internal/features/audit_logs/dto.go ------

```
package audit_logs

import "time"

type GetAuditLogsRequest struct {
    Limit      int        `form:"limit" json:"limit"`
    Offset     int        `form:"offset" json:"offset"`
    BeforeDate *time.Time `form:"beforeDate" json:"beforeDate"`
}

type GetAuditLogsResponse struct {
    AuditLogs []*AuditLog `json:"auditLogs"`
    Total     int64       `json:"total"`
    Limit     int         `json:"limit"`
    Offset    int         `json:"offset"`
}
```

------ backend/internal/features/audit_logs/models.go ------

```
package audit_logs

import (
    "time"

    "github.com/google/uuid"
)

type AuditLog struct {
    ID        uuid.UUID  `json:"id" gorm:"column:id"`
    UserID    *uuid.UUID `json:"userId" gorm:"column:user_id"`
    ProjectID *uuid.UUID `json:"projectId" gorm:"column:project_id"`
    Message   string     `json:"message" gorm:"column:message"`
    CreatedAt time.Time  `json:"createdAt" gorm:"column:created_at"`
}

func (AuditLog) TableName() string {
    return "audit_logs"
}
```

------ backend/internal/features/audit_logs/repository.go ------

```
package audit_logs

import (
    "databasus-backend/internal/storage"
    "time"

    "github.com/google/uuid"
)

type AuditLogRepository struct{}

func (r *AuditLogRepository) Create(auditLog *AuditLog) error {
    if auditLog.ID == uuid.Nil {
        auditLog.ID = uuid.New()
    }

    return storage.GetDb().Create(auditLog).Error
}

func (r *AuditLogRepository) GetGlobal(limit, offset int, beforeDate *time.Time) ([]*AuditLog, error) {
    var auditLogs []*AuditLog

    query := storage.GetDb().Order("created_at DESC")

    if beforeDate != nil {
        query = query.Where("created_at < ?", *beforeDate)
    }

    err := query.
        Limit(limit).
        Offset(offset).
        Find(&auditLogs).Error

    return auditLogs, err
}

func (r *AuditLogRepository) GetByUser(
    userID uuid.UUID,
    limit, offset int,
    beforeDate *time.Time,
) ([]*AuditLog, error) {
    var auditLogs []*AuditLog

    query := storage.GetDb().
        Where("user_id = ?", userID).
        Order("created_at DESC")

    if beforeDate != nil {
        query = query.Where("created_at < ?", *beforeDate)
    }

    err := query.
        Limit(limit).
        Offset(offset).
        Find(&auditLogs).Error

    return auditLogs, err
}

func (r *AuditLogRepository) GetByProject(
    projectID uuid.UUID,
    limit, offset int,
    beforeDate *time.Time,
) ([]*AuditLog, error) {
    var auditLogs []*AuditLog

    query := storage.GetDb().
        Where("project_id = ?", projectID).
        Order("created_at DESC")

    if beforeDate != nil {
        query = query.Where("created_at < ?", *beforeDate)
    }

    err := query.
        Limit(limit).
        Offset(offset).
        Find(&auditLogs).Error

    return auditLogs, err
}

func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
    var count int64
    query := storage.GetDb().Model(&AuditLog{})

    if beforeDate != nil {
        query = query.Where("created_at < ?", *beforeDate)
    }

    err := query.Count(&count).Error
    return count, err
}
```

------ backend/internal/features/audit_logs/service.go ------

```
package audit_logs

import (
    "errors"
    "log/slog"
    "time"

    user_enums "databasus-backend/internal/features/users/enums"
    user_models "databasus-backend/internal/features/users/models"

    "github.com/google/uuid"
)

type AuditLogService struct {
    auditLogRepository *AuditLogRepository
    logger             *slog.Logger
}

func (s *AuditLogService) WriteAuditLog(
    message string,
    userID *uuid.UUID,
    projectID *uuid.UUID,
) {
    auditLog := &AuditLog{
        UserID:    userID,
        ProjectID: projectID,
        Message:   message,
        CreatedAt: time.Now().UTC(),
    }

    err := s.auditLogRepository.Create(auditLog)
    if err != nil {
        s.logger.Error("failed to create audit log", "error", err)
        return
    }
}

func (s *AuditLogService) CreateAuditLog(auditLog *AuditLog) error {
    return s.auditLogRepository.Create(auditLog)
}

func (s *AuditLogService) GetGlobalAuditLogs(
    user *user_models.User,
    request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
    if user.Role != user_enums.UserRoleAdmin {
        return nil, errors.New("only administrators can view global audit logs")
    }

    limit := request.Limit
    if limit <= 0 || limit > 1000 {
        limit = 100
    }

    offset := max(request.Offset, 0)

    auditLogs, err := s.auditLogRepository.GetGlobal(limit, offset, request.BeforeDate)
    if err != nil {
        return nil, err
    }

    total, err := s.auditLogRepository.CountGlobal(request.BeforeDate)
    if err != nil {
        return nil, err
    }

    return &GetAuditLogsResponse{
        AuditLogs: auditLogs,
        Total:     total,
        Limit:     limit,
        Offset:    offset,
    }, nil
}

func (s *AuditLogService) GetUserAuditLogs(
    targetUserID uuid.UUID,
    user *user_models.User,
    request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
    // Users can view their own logs, ADMIN can view any user's logs
    if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
        return nil, errors.New("insufficient permissions to view user audit logs")
    }

    limit := request.Limit
    if limit <= 0 || limit > 1000 {
        limit = 100
    }

    offset := max(request.Offset, 0)

    auditLogs, err := s.auditLogRepository.GetByUser(targetUserID, limit, offset, request.BeforeDate)
    if err != nil {
        return nil, err
    }

    return &GetAuditLogsResponse{
        AuditLogs: auditLogs,
        Total:     int64(len(auditLogs)),
        Limit:     limit,
        Offset:    offset,
    }, nil
}

func (s *AuditLogService) GetProjectAuditLogs(
    projectID uuid.UUID,
    request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
    limit := request.Limit
    if limit <= 0 || limit > 1000 {
        limit = 100
    }

    offset := max(request.Offset, 0)

    auditLogs, err := s.auditLogRepository.GetByProject(projectID, limit, offset, request.BeforeDate)
    if err != nil {
        return nil, err
    }

    return &GetAuditLogsResponse{
        AuditLogs: auditLogs,
        Total:     int64(len(auditLogs)),
        Limit:     limit,
        Offset:    offset,
    }, nil
}
```

------ backend/internal/features/audit_logs/service_test.go ------

```
package audit_logs

import (
    "testing"
    "time"

    user_enums "databasus-backend/internal/features/users/enums"
    users_testing "databasus-backend/internal/features/users/testing"

    "github.com/google/uuid"
    "github.com/stretchr/testify/assert"
    "gorm.io/gorm"
)

func Test_AuditLogs_ProjectSpecificLogs(t *testing.T) {
    service := GetAuditLogService()
    user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
    user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
    project1ID, project2ID := uuid.New(), uuid.New()

    // Create test logs for projects
    createAuditLog(service, "Test project1 log first", &user1.UserID, &project1ID)
    createAuditLog(service, "Test project1 log second", &user2.UserID, &project1ID)
    createAuditLog(service, "Test project2 log first", &user1.UserID, &project2ID)
    createAuditLog(service, "Test project2 log second", &user2.UserID, &project2ID)
    createAuditLog(service, "Test no project log", &user1.UserID, nil)

    request := &GetAuditLogsRequest{Limit: 10, Offset: 0}

    // Test project 1 logs
    project1Response, err := service.GetProjectAuditLogs(project1ID, request)
    assert.NoError(t, err)
    assert.Equal(t, 2, len(project1Response.AuditLogs))

    messages := extractMessages(project1Response.AuditLogs)
    assert.Contains(t, messages, "Test project1 log first")
    assert.Contains(t, messages, "Test project1 log second")
    for _, log := range project1Response.AuditLogs {
        assert.Equal(t, &project1ID, log.ProjectID)
    }

    // Test project 2 logs
    project2Response, err := service.GetProjectAuditLogs(project2ID, request)
    assert.NoError(t, err)
    assert.Equal(t, 2, len(project2Response.AuditLogs))

    messages2 := extractMessages(project2Response.AuditLogs)
    assert.Contains(t, messages2, "Test project2 log first")
    assert.Contains(t, messages2, "Test project2 log second")

    // Test pagination
    limitedResponse, err := service.GetProjectAuditLogs(project1ID,
        &GetAuditLogsRequest{Limit: 1, Offset: 0})
    assert.NoError(t, err)
    assert.Equal(t, 1, len(limitedResponse.AuditLogs))
    assert.Equal(t, 1, limitedResponse.Limit)

    // Test beforeDate filter
    beforeTime := time.Now().UTC().Add(-1 * time.Minute)
    filteredResponse, err := service.GetProjectAuditLogs(project1ID,
        &GetAuditLogsRequest{Limit: 10, BeforeDate: &beforeTime})
    assert.NoError(t, err)
    for _, log := range filteredResponse.AuditLogs {
        assert.True(t, log.CreatedAt.Before(beforeTime))
    }
}

func createAuditLog(service *AuditLogService, message string, userID, projectID *uuid.UUID) {
    service.WriteAuditLog(message, userID, projectID)
}

func extractMessages(logs []*AuditLog) []string {
    messages := make([]string, len(logs))
    for i, log := range logs {
        messages[i] = log.Message
    }
    return messages
}

func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt time.Time) {
    log := &AuditLog{
        ID:        uuid.New(),
        UserID:    userID,
        Message:   message,
        CreatedAt: createdAt,
    }
    db.Create(log)
}
```

@@ -1,74 +0,0 @@
---
description:
globs:
alwaysApply: true
---

For DI files, use the implicit (positional) field declaration style, especially for controllers, services, repositories, use cases, etc. (not for simple data structures).

So, instead of:

```go
var orderController = &OrderController{
    orderService:   orderService,
    botUserService: bot_users.GetBotUserService(),
    botService:     bots.GetBotService(),
    userService:    users.GetUserService(),
}
```

use:

```go
var orderController = &OrderController{
    orderService,
    bot_users.GetBotUserService(),
    bots.GetBotService(),
    users.GetUserService(),
}
```

This is needed to avoid forgetting to wire in a new dependency when one is added; the sketch below shows why.
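
A minimal sketch of the mechanism (types invented for illustration, not from the codebase): a keyed literal keeps compiling when a field is added, silently leaving the new dependency nil, while a positional literal fails to compile until every field is supplied.

```go
package di_example

type OrderService struct{}
type BotService struct{}

type OrderController struct {
    orderService *OrderService
    botService   *BotService // newly added dependency
}

// Keyed literal: still compiles after botService is added,
// silently leaving it nil at runtime.
var keyedController = &OrderController{
    orderService: &OrderService{},
}

// Positional literal: "too few values in struct literal" compile
// error until the new dependency is wired in.
var positionalController = &OrderController{
    &OrderService{},
    &BotService{},
}
```

The compile error is the feature: it turns a forgotten wiring step into a build failure instead of a nil-pointer panic at runtime.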

---

Please enforce this style whenever a file looks like the following (note the service/controller/repository definitions and getters):

```go
var orderBackgroundService = &OrderBackgroundService{
    orderService:           orderService,
    orderPaymentRepository: orderPaymentRepository,
    botService:             bots.GetBotService(),
    paymentSettingsService: payment_settings.GetPaymentSettingsService(),

    orderSubscriptionListeners: []OrderSubscriptionListener{},
}

var orderController = &OrderController{
    orderService:   orderService,
    botUserService: bot_users.GetBotUserService(),
    botService:     bots.GetBotService(),
    userService:    users.GetUserService(),
}

func GetUniquePaymentRepository() *repositories.UniquePaymentRepository {
    return uniquePaymentRepository
}

func GetOrderPaymentRepository() *repositories.OrderPaymentRepository {
    return orderPaymentRepository
}

func GetOrderService() *OrderService {
    return orderService
}

func GetOrderController() *OrderController {
    return orderController
}

func GetOrderBackgroundService() *OrderBackgroundService {
    return orderBackgroundService
}

func GetOrderRepository() *repositories.OrderRepository {
    return orderRepository
}
```

@@ -1,27 +0,0 @@
---
description:
globs:
alwaysApply: true
---

When writing migrations:

- write them for PostgreSQL
- for UUID PRIMARY KEY columns, use gen_random_uuid()
- for time, use TIMESTAMPTZ (timestamp with time zone)
- declare the table, constraints, and indexes separately (table first, then the others one by one)
- format the SQL prettily (add spaces, align column types) and split constraints across lines. Example:

```sql
CREATE TABLE marketplace_info (
    bot_id            UUID    PRIMARY KEY DEFAULT gen_random_uuid(),
    title             TEXT    NOT NULL,
    description       TEXT    NOT NULL,
    short_description TEXT    NOT NULL,
    tutorial_url      TEXT,
    info_order        BIGINT  NOT NULL DEFAULT 0,
    is_published      BOOLEAN NOT NULL DEFAULT FALSE
);

ALTER TABLE marketplace_info_images
    ADD CONSTRAINT fk_marketplace_info_images_bot_id
    FOREIGN KEY (bot_id)
    REFERENCES marketplace_info (bot_id);
```
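
Indexes follow the same split, declared on their own after the table and constraints; a hedged continuation of the example above (this index is illustrative, not part of the original migration):

```sql
CREATE INDEX idx_marketplace_info_is_published
    ON marketplace_info (is_published);
```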

@@ -1,12 +0,0 @@
---
description:
globs:
alwaysApply: true
---

When applying changes, do not forget to refactor old code.

You can shorten it, make it more readable, improve code quality, etc.
Common logic can be extracted into functions, constants, files, etc.

After each large change of more than ~50-100 lines of code, always run `make lint` (from the backend root folder) and, if you changed the frontend, run `npm run format` (from the frontend root folder).

@@ -1,147 +0,0 @@
---
description:
globs:
alwaysApply: true
---

After writing tests, always launch them and verify that they pass.

## Test Naming Format

Use these naming patterns:

- `Test_WhatWeDo_WhatWeExpect`
- `Test_WhatWeDo_WhichConditions_WhatWeExpect`

## Examples from Real Codebase:

- `Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated`
- `Test_UpdateProject_WhenUserIsProjectAdmin_ProjectUpdated`
- `Test_DeleteApiKey_WhenUserIsProjectMember_ReturnsForbidden`
- `Test_GetProjectAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly`
- `Test_ProjectLifecycleE2E_CompletesSuccessfully`

## Testing Philosophy

**Prefer Controllers Over Unit Tests:**

- Test through HTTP endpoints via controllers whenever possible
- Avoid testing repositories and services in isolation - test via the API instead
- Only use unit tests for complex model logic when no API exists
- Name test files `controller_test.go` or `service_test.go`, not `integration_test.go`

**Extract Common Logic to Testing Utilities:**

- Create `testing.go` or `testing/testing.go` files for shared test utilities
- Extract helpers for router creation, user setup, and model creation (through the API, not just struct construction)
- Reuse common patterns across different test files

**Refactor Existing Tests:**

- When working with existing tests, always look for opportunities to refactor and improve
- Extract repetitive setup code into common utilities
- Simplify complex tests by breaking them into smaller, focused tests
- Replace inline test data creation with reusable helper functions
- Consolidate similar test patterns across different test files
- Make tests more readable and maintainable for other developers

## Testing Utilities Structure

**Create `testing.go` or `testing/testing.go` files with common utilities:**

```go
package projects_testing

// CreateTestRouter creates a unified router for all controllers
func CreateTestRouter(controllers ...ControllerInterface) *gin.Engine {
    gin.SetMode(gin.TestMode)
    router := gin.New()
    v1 := router.Group("/api/v1")
    protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))

    for _, controller := range controllers {
        if routerGroup, ok := protected.(*gin.RouterGroup); ok {
            controller.RegisterRoutes(routerGroup)
        }
    }
    return router
}

// CreateTestProjectViaAPI creates a project through the HTTP API
func CreateTestProjectViaAPI(name string, owner *users_dto.SignInResponseDTO, router *gin.Engine) (*projects_models.Project, string) {
    request := projects_dto.CreateProjectRequestDTO{Name: name}
    w := MakeAPIRequest(router, "POST", "/api/v1/projects", "Bearer "+owner.Token, request)
    // Handle response...
    return project, owner.Token
}

// AddMemberToProject adds a member via an API call
func AddMemberToProject(project *projects_models.Project, member *users_dto.SignInResponseDTO, role users_enums.ProjectRole, ownerToken string, router *gin.Engine) {
    // Implementation...
}
```

## Controller Test Examples

**Permission-based testing:**

```go
func Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated(t *testing.T) {
    router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
    owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
    project, _ := projects_testing.CreateTestProjectViaAPI("Test Project", owner, router)

    request := CreateApiKeyRequestDTO{Name: "Test API Key"}
    var response ApiKey
    test_utils.MakePostRequestAndUnmarshal(t, router, "/api/v1/projects/api-keys/"+project.ID.String(), "Bearer "+owner.Token, request, http.StatusOK, &response)

    assert.Equal(t, "Test API Key", response.Name)
    assert.NotEmpty(t, response.Token)
}
```

**Cross-project security testing:**

```go
func Test_UpdateApiKey_WithApiKeyFromDifferentProject_ReturnsBadRequest(t *testing.T) {
    router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
    owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
    owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
    project1, _ := projects_testing.CreateTestProjectViaAPI("Project 1", owner1, router)
    project2, _ := projects_testing.CreateTestProjectViaAPI("Project 2", owner2, router)

    apiKey := CreateTestApiKey("Cross Project Key", project1.ID, owner1.Token, router)

    // Try to update via a different project's endpoint
    newName := "Hacked Key"
    request := UpdateApiKeyRequestDTO{Name: &newName}
    resp := test_utils.MakePutRequest(t, router, "/api/v1/projects/api-keys/"+project2.ID.String()+"/"+apiKey.ID.String(), "Bearer "+owner2.Token, request, http.StatusBadRequest)

    assert.Contains(t, string(resp.Body), "API key does not belong to this project")
}
```

**E2E lifecycle testing:**

```go
func Test_ProjectLifecycleE2E_CompletesSuccessfully(t *testing.T) {
    router := projects_testing.CreateTestRouter(GetProjectController(), GetMembershipController())

    // 1. Create project
    owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
    project := projects_testing.CreateTestProject("E2E Project", owner, router)

    // 2. Add member
    member := users_testing.CreateTestUser(users_enums.UserRoleMember)
    projects_testing.AddMemberToProject(project, member, users_enums.ProjectRoleMember, owner.Token, router)

    // 3. Promote to admin
    projects_testing.ChangeMemberRole(project, member.UserID, users_enums.ProjectRoleAdmin, owner.Token, router)

    // 4. Transfer ownership
    projects_testing.TransferProjectOwnership(project, member.UserID, owner.Token, router)

    // 5. Verify new owner can manage the project
    finalProject := projects_testing.GetProject(project.ID, member.Token, router)
    assert.Equal(t, project.ID, finalProject.ID)
}
```

@@ -1,6 +0,0 @@
---
description:
globs:
alwaysApply: true
---

Always use time.Now().UTC() instead of time.Now().
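
A minimal sketch of the difference (standard library only):

```go
package main

import (
    "fmt"
    "time"
)

func main() {
    local := time.Now()     // zone depends on the host; avoid persisting this
    utc := time.Now().UTC() // normalized; what the rule requires

    fmt.Println(local.Location(), utc.Location()) // e.g. "Local UTC"
}
```

Persisting time.Now() makes stored timestamps depend on each host's zone configuration; normalizing to UTC keeps them comparable and matches the TIMESTAMPTZ convention in the migrations rule.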

```
@@ -2,8 +2,10 @@
DEV_DB_NAME=databasus
DEV_DB_USERNAME=postgres
DEV_DB_PASSWORD=Q1234567
#app
# app
ENV_MODE=development
# logging
SHOW_DB_INSTALLATION_VERIFICATION_LOGS=true
# db
DATABASE_DSN=host=dev-db user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
```

```
@@ -16,7 +16,6 @@ import (
    "databasus-backend/internal/features/audit_logs"
    "databasus-backend/internal/features/backups/backups"
    "databasus-backend/internal/features/backups/backups/backuping"
    backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
    backups_download "databasus-backend/internal/features/backups/backups/download"
    backups_config "databasus-backend/internal/features/backups/config"
    "databasus-backend/internal/features/databases"

@@ -28,6 +27,8 @@ import (
    "databasus-backend/internal/features/restores"
    "databasus-backend/internal/features/storages"
    system_healthcheck "databasus-backend/internal/features/system/healthcheck"
    task_cancellation "databasus-backend/internal/features/tasks/cancellation"
    task_registry "databasus-backend/internal/features/tasks/registry"
    users_controllers "databasus-backend/internal/features/users/controllers"
    users_middleware "databasus-backend/internal/features/users/middleware"
    users_services "databasus-backend/internal/features/users/services"

@@ -59,6 +60,8 @@ func main() {
    cache_utils.TestCacheConnection()

    if config.GetEnv().IsPrimaryNode {
        log.Info("Clearing cache...")

        err := cache_utils.ClearAllCache()
        if err != nil {
            log.Error("Failed to clear cache", "error", err)

@@ -239,7 +242,7 @@ func setUpDependencies() {
    notifiers.SetupDependencies()
    storages.SetupDependencies()
    backups_config.SetupDependencies()
    backups_cancellation.SetupDependencies()
    task_cancellation.SetupDependencies()
}

func runBackgroundTasks(log *slog.Logger) {

@@ -257,14 +260,14 @@ func runBackgroundTasks(log *slog.Logger) {
        cancel()
    }()

    err := files_utils.CleanFolder(config.GetEnv().TempFolder)
    if err != nil {
        log.Error("Failed to clean temp folder", "error", err)
    }

    if config.GetEnv().IsPrimaryNode {
        log.Info("Starting primary node background tasks...")

        err := files_utils.CleanFolder(config.GetEnv().TempFolder)
        if err != nil {
            log.Error("Failed to clean temp folder", "error", err)
        }

        go runWithPanicLogging(log, "backup background service", func() {
            backuping.GetBackupsScheduler().Run(ctx)
        })

@@ -284,6 +287,10 @@ func runBackgroundTasks(log *slog.Logger) {
        go runWithPanicLogging(log, "download token cleanup background service", func() {
            backups_download.GetDownloadTokenBackgroundService().Run(ctx)
        })

        go runWithPanicLogging(log, "task nodes registry background service", func() {
            task_registry.GetTaskNodesRegistry().Run(ctx)
        })
    } else {
        log.Info("Skipping primary node tasks as not primary node")
    }
```

```
@@ -28,7 +28,6 @@ require (
    github.com/valkey-io/valkey-go v1.0.70
    go.mongodb.org/mongo-driver v1.17.6
    golang.org/x/crypto v0.46.0
    golang.org/x/time v0.14.0
    gorm.io/driver/postgres v1.5.11
    gorm.io/gorm v1.26.1
)

@@ -186,6 +185,7 @@ require (
    go.yaml.in/yaml/v2 v2.4.3 // indirect
    golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
    golang.org/x/term v0.38.0 // indirect
    golang.org/x/time v0.14.0 // indirect
    gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
    gopkg.in/validator.v2 v2.0.1 // indirect
    moul.io/http2curl/v2 v2.3.0 // indirect
```

@@ -30,6 +30,8 @@ type EnvVariables struct {
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`

ShowDbInstallationVerificationLogs bool `env:"SHOW_DB_INSTALLATION_VERIFICATION_LOGS"`

NodeID string
IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
@@ -169,6 +171,11 @@ func loadEnvVariables() {
os.Exit(1)
}

// Set default value for ShowDbInstallationVerificationLogs if not defined
if os.Getenv("SHOW_DB_INSTALLATION_VERIFICATION_LOGS") == "" {
env.ShowDbInstallationVerificationLogs = true
}

for _, arg := range os.Args {
if strings.Contains(arg, "test") {
env.IsTesting = true
@@ -192,16 +199,36 @@ func loadEnvVariables() {
log.Info("ENV_MODE loaded", "mode", env.EnvMode)

env.PostgresesInstallDir = filepath.Join(backendRoot, "tools", "postgresql")
tools.VerifyPostgresesInstallation(log, env.EnvMode, env.PostgresesInstallDir)
tools.VerifyPostgresesInstallation(
log,
env.EnvMode,
env.PostgresesInstallDir,
env.ShowDbInstallationVerificationLogs,
)

env.MysqlInstallDir = filepath.Join(backendRoot, "tools", "mysql")
tools.VerifyMysqlInstallation(log, env.EnvMode, env.MysqlInstallDir)
tools.VerifyMysqlInstallation(
log,
env.EnvMode,
env.MysqlInstallDir,
env.ShowDbInstallationVerificationLogs,
)

env.MariadbInstallDir = filepath.Join(backendRoot, "tools", "mariadb")
tools.VerifyMariadbInstallation(log, env.EnvMode, env.MariadbInstallDir)
tools.VerifyMariadbInstallation(
log,
env.EnvMode,
env.MariadbInstallDir,
env.ShowDbInstallationVerificationLogs,
)

env.MongodbInstallDir = filepath.Join(backendRoot, "tools", "mongodb")
tools.VerifyMongodbInstallation(log, env.EnvMode, env.MongodbInstallDir)
tools.VerifyMongodbInstallation(
log,
env.EnvMode,
env.MongodbInstallDir,
env.ShowDbInstallationVerificationLogs,
)

env.NodeID = uuid.New().String()
if env.NodeNetworkThroughputMBs == 0 {

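Net effect of the two config hunks above: parsing the variable as a bool alone cannot distinguish an unset variable from an explicit false, so the explicit os.Getenv check keeps verification logging on by default. An operator disables it only by setting the variable explicitly, presumably as:

SHOW_DB_INSTALLATION_VERIFICATION_LOGS=false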
@@ -3,11 +3,12 @@ package backuping
import (
"context"
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
task_registry "databasus-backend/internal/features/tasks/registry"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
"errors"
@@ -20,6 +21,11 @@ import (
"github.com/google/uuid"
)

const (
heartbeatTickerInterval = 15 * time.Second
backuperHeathcheckThreshold = 5 * time.Minute
)

type BackuperNode struct {
databaseService *databases.DatabaseService
fieldEncryptor util_encryption.FieldEncryptor
@@ -28,8 +34,8 @@ type BackuperNode struct {
backupConfigService *backups_config.BackupConfigService
storageService *storages.StorageService
notificationSender backups_core.NotificationSender
backupCancelManager *backups_cancellation.BackupCancelManager
nodesRegistry *BackupNodesRegistry
backupCancelManager *tasks_cancellation.TaskCancelManager
tasksRegistry *task_registry.TaskNodesRegistry
logger *slog.Logger
createBackupUseCase backups_core.CreateBackupUsecase
nodeID uuid.UUID
@@ -42,20 +48,19 @@ func (n *BackuperNode) Run(ctx context.Context) {

throughputMBs := config.GetEnv().NodeNetworkThroughputMBs

backupNode := BackupNode{
backupNode := task_registry.TaskNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: time.Now().UTC(),
}

if err := n.nodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
if err := n.tasksRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}

backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
n.MakeBackup(backupID, isCallNotifier)
if err := n.nodesRegistry.PublishBackupCompletion(n.nodeID.String(), backupID); err != nil {
if err := n.tasksRegistry.PublishTaskCompletion(n.nodeID.String(), backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
@@ -66,17 +71,17 @@ func (n *BackuperNode) Run(ctx context.Context) {
}
}

if err := n.nodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID.String(), backupHandler); err != nil {
if err := n.tasksRegistry.SubscribeNodeForTasksAssignment(n.nodeID.String(), backupHandler); err != nil {
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
panic(err)
}
defer func() {
if err := n.nodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
if err := n.tasksRegistry.UnsubscribeNodeForTasksAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
}
}()

ticker := time.NewTicker(15 * time.Second)
ticker := time.NewTicker(heartbeatTickerInterval)
defer ticker.Stop()

n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
@@ -86,7 +91,7 @@ func (n *BackuperNode) Run(ctx context.Context) {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)

if err := n.nodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
if err := n.tasksRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
}

@@ -98,7 +103,7 @@ func (n *BackuperNode) Run(ctx context.Context) {
}

func (n *BackuperNode) IsBackuperRunning() bool {
return n.lastHeartbeat.After(time.Now().UTC().Add(-5 * time.Minute))
return n.lastHeartbeat.After(time.Now().UTC().Add(-backuperHeathcheckThreshold))
}

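The two constants introduced above encode a simple liveness rule: the node is healthy while its last heartbeat is newer than the threshold. A self-contained sketch of that check (names are illustrative, not from the diff):

package main

import (
	"fmt"
	"time"
)

// isAlive reports whether a heartbeat observed at last is still within
// threshold of now; with a 15s ticker and a 5m threshold, roughly 20
// consecutive heartbeats can be missed before a node is declared dead.
func isAlive(now, last time.Time, threshold time.Duration) bool {
	return last.After(now.Add(-threshold))
}

func main() {
	now := time.Now().UTC()
	fmt.Println(isAlive(now, now.Add(-4*time.Minute), 5*time.Minute)) // true
	fmt.Println(isAlive(now, now.Add(-6*time.Minute), 5*time.Minute)) // false
}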
func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
@@ -147,8 +152,8 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
}

ctx, cancel := context.WithCancel(context.Background())
n.backupCancelManager.RegisterBackup(backup.ID, cancel)
defer n.backupCancelManager.UnregisterBackup(backup.ID)
n.backupCancelManager.RegisterTask(backup.ID, cancel)
defer n.backupCancelManager.UnregisterTask(backup.ID)

backupMetadata, err := n.createBackupUseCase.Execute(
ctx,
@@ -161,6 +166,17 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
if err != nil {
errMsg := err.Error()

// Log detailed error information for debugging
n.logger.Error("Backup execution failed",
"backupId", backup.ID,
"databaseId", databaseID,
"databaseType", database.Type,
"storageId", storage.ID,
"storageType", storage.Type,
"error", err,
"errorMessage", errMsg,
)

// Check if backup was cancelled (not due to shutdown)
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
strings.Contains(errMsg, "context canceled") ||
@@ -168,6 +184,12 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
isShutdown := strings.Contains(errMsg, "shutdown")

if isCancelled && !isShutdown {
n.logger.Warn("Backup was cancelled by user or system",
"backupId", backup.ID,
"isCancelled", isCancelled,
"isShutdown", isShutdown,
)

backup.Status = backups_core.BackupStatusCanceled
backup.BackupDurationMs = time.Since(start).Milliseconds()
backup.BackupSizeMb = 0
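Cancellation is detected here by substring-matching the error text (the hunk boundary hides one further condition of the isCancelled expression). A standalone sketch of that classification, with errors.Is(err, context.Canceled) added as the usual stricter check when the error chain survives wrapping; this is an illustration, not the project's code:

package main

import (
	"context"
	"errors"
	"fmt"
	"strings"
)

// classify mirrors the substring checks above: an error counts as a user or
// system cancellation only when it is not part of a process shutdown.
func classify(err error) bool {
	msg := err.Error()
	isCancelled := strings.Contains(msg, "backup cancelled") ||
		strings.Contains(msg, "context canceled") ||
		errors.Is(err, context.Canceled) // stricter alternative when wrapping is preserved
	isShutdown := strings.Contains(msg, "shutdown")
	return isCancelled && !isShutdown
}

func main() {
	fmt.Println(classify(errors.New("backup cancelled by user")))   // true
	fmt.Println(classify(errors.New("shutdown: context canceled"))) // false
}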
@@ -335,10 +357,9 @@ func (n *BackuperNode) SendBackupNotification(
}
}

func (n *BackuperNode) sendHeartbeat(backupNode *BackupNode) {
func (n *BackuperNode) sendHeartbeat(backupNode *task_registry.TaskNode) {
n.lastHeartbeat = time.Now().UTC()
backupNode.LastHeartbeat = time.Now().UTC()
if err := n.nodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
if err := n.tasksRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
n.logger.Error("Failed to send heartbeat", "error", err)
}
}

@@ -2,15 +2,15 @@ package backuping

import (
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
task_registry "databasus-backend/internal/features/tasks/registry"
workspaces_services "databasus-backend/internal/features/workspaces/services"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
"time"
@@ -20,15 +20,9 @@ import (

var backupRepository = &backups_core.BackupRepository{}

var backupCancelManager = backups_cancellation.GetBackupCancelManager()
var taskCancelManager = tasks_cancellation.GetTaskCancelManager()

var nodesRegistry = &BackupNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
}
var nodesRegistry = task_registry.GetTaskNodesRegistry()

func getNodeID() uuid.UUID {
nodeIDStr := config.GetEnv().NodeID
@@ -48,7 +42,7 @@ var backuperNode = &BackuperNode{
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
backupCancelManager,
taskCancelManager,
nodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
@@ -60,7 +54,7 @@ var backupsScheduler = &BackupsScheduler{
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
backupCancelManager,
taskCancelManager,
nodesRegistry,
time.Now().UTC(),
logger.GetLogger(),

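The wiring above swaps a locally-constructed registry for a shared accessor. The diff shows only the call site; a hypothetical sketch of such a sync.Once-based singleton, with the construction details assumed:

package task_registry

import "sync"

var (
	registryOnce sync.Once
	registryInst *TaskNodesRegistry
)

// GetTaskNodesRegistry returns one process-wide registry instance; the real
// constructor presumably wires the Valkey client, logger, and pub/sub
// managers that the deleted BackupNodesRegistry literal used to take.
func GetTaskNodesRegistry() *TaskNodesRegistry {
	registryOnce.Do(func() {
		registryInst = &TaskNodesRegistry{}
	})
	return registryInst
}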
@@ -1,32 +1,6 @@
package backuping

import (
"time"

"github.com/google/uuid"
)

type BackupNode struct {
ID uuid.UUID `json:"id"`
ThroughputMBs int `json:"throughputMBs"`
LastHeartbeat time.Time `json:"lastHeartbeat"`
}

type BackupNodeStats struct {
ID uuid.UUID `json:"id"`
ActiveBackups int `json:"activeBackups"`
}

type BackupSubmitMessage struct {
NodeID string `json:"nodeId"`
BackupID string `json:"backupId"`
IsCallNotifier bool `json:"isCallNotifier"`
}

type BackupCompletionMessage struct {
NodeID string `json:"nodeId"`
BackupID string `json:"backupId"`
}
import "github.com/google/uuid"

type BackupToNodeRelation struct {
NodeID uuid.UUID `json:"nodeId"`

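The node and message types removed here move into the shared task_registry package. Based only on the composite literal used in node.go above (task_registry.TaskNode{ID: ..., ThroughputMBs: ..., LastHeartbeat: ...}) and the json tags of the deleted BackupNode, the generic replacement presumably looks like:

package task_registry

import (
	"time"

	"github.com/google/uuid"
)

// Hypothetical shape of the generic node record; field names come from the
// call sites in this diff, the json tags are carried over as an assumption.
type TaskNode struct {
	ID            uuid.UUID `json:"id"`
	ThroughputMBs int       `json:"throughputMBs"`
	LastHeartbeat time.Time `json:"lastHeartbeat"`
}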
@@ -1,448 +0,0 @@
package backuping

import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"time"

cache_utils "databasus-backend/internal/util/cache"

"github.com/google/uuid"
"github.com/valkey-io/valkey-go"
)

const (
nodeInfoKeyPrefix = "node:"
nodeInfoKeySuffix = ":info"
nodeActiveBackupsPrefix = "node:"
nodeActiveBackupsSuffix = ":active_backups"
backupSubmitChannel = "backup:submit"
backupCompletionChannel = "backup:completion"
)

// BackupNodesRegistry helps to sync backups scheduler and backup nodes.
// Features:
// - Track node availability and load level
// - Assign from scheduler to node backups needed to be processed
// - Notify scheduler from node about backup completion
type BackupNodesRegistry struct {
client valkey.Client
logger *slog.Logger
timeout time.Duration
pubsubBackups *cache_utils.PubSubManager
pubsubCompletions *cache_utils.PubSubManager
}

func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()

var allKeys []string
cursor := uint64(0)
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix

for {
result := r.client.Do(
ctx,
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
)

if result.Error() != nil {
return nil, fmt.Errorf("failed to scan node keys: %w", result.Error())
}

scanResult, err := result.AsScanEntry()
if err != nil {
return nil, fmt.Errorf("failed to parse scan result: %w", err)
}

allKeys = append(allKeys, scanResult.Elements...)

cursor = scanResult.Cursor
if cursor == 0 {
break
}
}

if len(allKeys) == 0 {
return []BackupNode{}, nil
}

keyDataMap, err := r.pipelineGetKeys(allKeys)
if err != nil {
return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
}

var nodes []BackupNode
for key, data := range keyDataMap {
var node BackupNode
if err := json.Unmarshal(data, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
continue
}
nodes = append(nodes, node)
}

return nodes, nil
}

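GetAvailableNodes enumerates keys with cursor-based SCAN rather than KEYS, so the server answers one bounded page per round trip instead of blocking on a single huge reply. The loop pattern in isolation, mirroring the valkey-go calls used in this file:

package cacheutil

import (
	"context"

	"github.com/valkey-io/valkey-go"
)

// scanAll repeats SCAN until the server hands back cursor 0, collecting one
// page of matching keys per round trip.
func scanAll(ctx context.Context, client valkey.Client, pattern string) ([]string, error) {
	var keys []string
	cursor := uint64(0)
	for {
		resp := client.Do(ctx, client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build())
		if err := resp.Error(); err != nil {
			return nil, err
		}
		entry, err := resp.AsScanEntry()
		if err != nil {
			return nil, err
		}
		keys = append(keys, entry.Elements...)
		if cursor = entry.Cursor; cursor == 0 {
			return keys, nil
		}
	}
}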
func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()

var allKeys []string
cursor := uint64(0)
pattern := nodeActiveBackupsPrefix + "*" + nodeActiveBackupsSuffix

for {
result := r.client.Do(
ctx,
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build(),
)

if result.Error() != nil {
return nil, fmt.Errorf("failed to scan active backups keys: %w", result.Error())
}

scanResult, err := result.AsScanEntry()
if err != nil {
return nil, fmt.Errorf("failed to parse scan result: %w", err)
}

allKeys = append(allKeys, scanResult.Elements...)

cursor = scanResult.Cursor
if cursor == 0 {
break
}
}

if len(allKeys) == 0 {
return []BackupNodeStats{}, nil
}

keyDataMap, err := r.pipelineGetKeys(allKeys)
if err != nil {
return nil, fmt.Errorf("failed to pipeline get active backups keys: %w", err)
}

var stats []BackupNodeStats
for key, data := range keyDataMap {
nodeID := r.extractNodeIDFromKey(key, nodeActiveBackupsPrefix, nodeActiveBackupsSuffix)

count, err := r.parseIntFromBytes(data)
if err != nil {
r.logger.Warn("Failed to parse active backups count", "key", key, "error", err)
continue
}

stat := BackupNodeStats{
ID: nodeID,
ActiveBackups: int(count),
}
stats = append(stats, stat)
}

return stats, nil
}

func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID string) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()

key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID, nodeActiveBackupsSuffix)
result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())

if result.Error() != nil {
return fmt.Errorf(
"failed to increment backups in progress for node %s: %w",
nodeID,
result.Error(),
)
}

return nil
}

func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID string) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()

key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID, nodeActiveBackupsSuffix)
result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())

if result.Error() != nil {
return fmt.Errorf(
"failed to decrement backups in progress for node %s: %w",
nodeID,
result.Error(),
)
}

newValue, err := result.AsInt64()
if err != nil {
return fmt.Errorf("failed to parse decremented value for node %s: %w", nodeID, err)
}

if newValue < 0 {
setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
setCancel()
r.logger.Warn("Active backups counter went below 0, reset to 0", "nodeID", nodeID)
}

return nil
}

func (r *BackupNodesRegistry) HearthbeatNodeInRegistry(now time.Time, backupNode BackupNode) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()

backupNode.LastHeartbeat = now

data, err := json.Marshal(backupNode)
if err != nil {
return fmt.Errorf("failed to marshal backup node: %w", err)
}

key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
result := r.client.Do(
ctx,
r.client.B().Set().Key(key).Value(string(data)).Build(),
)

if result.Error() != nil {
return fmt.Errorf("failed to register node %s: %w", backupNode.ID, result.Error())
}

return nil
}

func (r *BackupNodesRegistry) UnregisterNodeFromRegistry(backupNode BackupNode) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()

infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
counterKey := fmt.Sprintf(
"%s%s%s",
nodeActiveBackupsPrefix,
backupNode.ID.String(),
nodeActiveBackupsSuffix,
)

result := r.client.Do(
ctx,
r.client.B().Del().Key(infoKey, counterKey).Build(),
)

if result.Error() != nil {
return fmt.Errorf("failed to unregister node %s: %w", backupNode.ID, result.Error())
}

r.logger.Info("Unregistered node from registry", "nodeID", backupNode.ID)
return nil
}

func (r *BackupNodesRegistry) AssignBackupToNode(
targetNodeID string,
backupID uuid.UUID,
isCallNotifier bool,
) error {
ctx := context.Background()

message := BackupSubmitMessage{
NodeID: targetNodeID,
BackupID: backupID.String(),
IsCallNotifier: isCallNotifier,
}

messageJSON, err := json.Marshal(message)
if err != nil {
return fmt.Errorf("failed to marshal backup submit message: %w", err)
}

err = r.pubsubBackups.Publish(ctx, backupSubmitChannel, string(messageJSON))
if err != nil {
return fmt.Errorf("failed to publish backup submit message: %w", err)
}

return nil
}

func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment(
nodeID string,
handler func(backupID uuid.UUID, isCallNotifier bool),
) error {
ctx := context.Background()

wrappedHandler := func(message string) {
var msg BackupSubmitMessage
if err := json.Unmarshal([]byte(message), &msg); err != nil {
r.logger.Warn("Failed to unmarshal backup submit message", "error", err)
return
}

if msg.NodeID != nodeID {
return
}

backupID, err := uuid.Parse(msg.BackupID)
if err != nil {
r.logger.Warn(
"Failed to parse backup ID from message",
"backupId",
msg.BackupID,
"error",
err,
)
return
}

handler(backupID, msg.IsCallNotifier)
}

err := r.pubsubBackups.Subscribe(ctx, backupSubmitChannel, wrappedHandler)
if err != nil {
return fmt.Errorf("failed to subscribe to backup submit channel: %w", err)
}

r.logger.Info("Subscribed to backup submit channel", "nodeID", nodeID)
return nil
}

func (r *BackupNodesRegistry) UnsubscribeNodeForBackupsAssignments() error {
err := r.pubsubBackups.Close()
if err != nil {
return fmt.Errorf("failed to unsubscribe from backup submit channel: %w", err)
}

r.logger.Info("Unsubscribed from backup submit channel")
return nil
}

func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID string, backupID uuid.UUID) error {
ctx := context.Background()

message := BackupCompletionMessage{
NodeID: nodeID,
BackupID: backupID.String(),
}

messageJSON, err := json.Marshal(message)
if err != nil {
return fmt.Errorf("failed to marshal backup completion message: %w", err)
}

err = r.pubsubCompletions.Publish(ctx, backupCompletionChannel, string(messageJSON))
if err != nil {
return fmt.Errorf("failed to publish backup completion message: %w", err)
}

return nil
}

func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
handler func(nodeID string, backupID uuid.UUID),
) error {
ctx := context.Background()

wrappedHandler := func(message string) {
var msg BackupCompletionMessage
if err := json.Unmarshal([]byte(message), &msg); err != nil {
r.logger.Warn("Failed to unmarshal backup completion message", "error", err)
return
}

backupID, err := uuid.Parse(msg.BackupID)
if err != nil {
r.logger.Warn(
"Failed to parse backup ID from completion message",
"backupId",
msg.BackupID,
"error",
err,
)
return
}

handler(msg.NodeID, backupID)
}

err := r.pubsubCompletions.Subscribe(ctx, backupCompletionChannel, wrappedHandler)
if err != nil {
return fmt.Errorf("failed to subscribe to backup completion channel: %w", err)
}

r.logger.Info("Subscribed to backup completion channel")
return nil
}

func (r *BackupNodesRegistry) UnsubscribeForBackupsCompletions() error {
err := r.pubsubCompletions.Close()
if err != nil {
return fmt.Errorf("failed to unsubscribe from backup completion channel: %w", err)
}

r.logger.Info("Unsubscribed from backup completion channel")
return nil
}

func (r *BackupNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
nodeIDStr := strings.TrimPrefix(key, prefix)
nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix)

nodeID, err := uuid.Parse(nodeIDStr)
if err != nil {
r.logger.Warn("Failed to parse node ID from key", "key", key, "error", err)
return uuid.Nil
}

return nodeID
}

func (r *BackupNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
if len(keys) == 0 {
return make(map[string][]byte), nil
}

ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()

commands := make([]valkey.Completed, 0, len(keys))
for _, key := range keys {
commands = append(commands, r.client.B().Get().Key(key).Build())
}

results := r.client.DoMulti(ctx, commands...)

keyDataMap := make(map[string][]byte, len(keys))
for i, result := range results {
if result.Error() != nil {
r.logger.Warn("Failed to get key in pipeline", "key", keys[i], "error", result.Error())
continue
}

data, err := result.AsBytes()
if err != nil {
r.logger.Warn("Failed to parse key data in pipeline", "key", keys[i], "error", err)
continue
}

keyDataMap[keys[i]] = data
}

return keyDataMap, nil
}

func (r *BackupNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
str := string(data)
var count int64
_, err := fmt.Sscanf(str, "%d", &count)
if err != nil {
return 0, fmt.Errorf("failed to parse integer from bytes: %w", err)
}
return count, nil
}
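This file is deleted outright; per the call-site renames visible in node.go and scheduler.go above, its role moves to the generic task_registry.TaskNodesRegistry, with the backup-specific methods mapped roughly one-to-one:

// Presumed mapping onto the generic registry (names taken from the
// renames elsewhere in this diff):
//   PublishBackupCompletion              -> PublishTaskCompletion
//   SubscribeNodeForBackupsAssignment    -> SubscribeNodeForTasksAssignment
//   UnsubscribeNodeForBackupsAssignments -> UnsubscribeNodeForTasksAssignments
//   SubscribeForBackupsCompletions       -> SubscribeForTasksCompletions
//   UnsubscribeForBackupsCompletions     -> UnsubscribeForTasksCompletions
//   HearthbeatNodeInRegistry and UnregisterNodeFromRegistry keep their names.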
@@ -1,904 +0,0 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_HearthbeatNodeInRegistry_RegistersNodeWithTTL(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, node.ID, nodes[0].ID)
|
||||
assert.Equal(t, node.ThroughputMBs, nodes[0].ThroughputMBs)
|
||||
}
|
||||
|
||||
func Test_UnregisterNodeFromRegistry_RemovesNodeAndCounter(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.UnregisterNodeFromRegistry(node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, nodes)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, stats)
|
||||
}
|
||||
|
||||
func Test_GetAvailableNodes_ReturnsAllRegisteredNodes(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
node3 := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node1)
|
||||
defer cleanupTestNode(registry, node2)
|
||||
defer cleanupTestNode(registry, node3)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 3)
|
||||
|
||||
nodeIDs := make(map[uuid.UUID]bool)
|
||||
for _, node := range nodes {
|
||||
nodeIDs[node.ID] = true
|
||||
}
|
||||
assert.True(t, nodeIDs[node1.ID])
|
||||
assert.True(t, nodeIDs[node2.ID])
|
||||
assert.True(t, nodeIDs[node3.ID])
|
||||
}
|
||||
|
||||
func Test_GetAvailableNodes_WhenNoNodesExist_ReturnsEmptySlice(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, nodes)
|
||||
assert.Empty(t, nodes)
|
||||
}
|
||||
|
||||
func Test_IncrementBackupsInProgress_IncrementsCounter(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, stats, 1)
|
||||
assert.Equal(t, node.ID, stats[0].ID)
|
||||
assert.Equal(t, 1, stats[0].ActiveBackups)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err = registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, stats, 1)
|
||||
assert.Equal(t, 2, stats[0].ActiveBackups)
|
||||
}
|
||||
|
||||
func Test_DecrementBackupsInProgress_DecrementsCounter(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, stats[0].ActiveBackups)
|
||||
|
||||
err = registry.DecrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err = registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, stats[0].ActiveBackups)
|
||||
|
||||
err = registry.DecrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err = registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, stats[0].ActiveBackups)
|
||||
}
|
||||
|
||||
func Test_DecrementBackupsInProgress_WhenNegative_ResetsToZero(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.DecrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, stats, 1)
|
||||
assert.Equal(t, 0, stats[0].ActiveBackups)
|
||||
}
|
||||
|
||||
func Test_GetBackupNodesStats_ReturnsStatsForAllNodes(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
node3 := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node1)
|
||||
defer cleanupTestNode(registry, node2)
|
||||
defer cleanupTestNode(registry, node3)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node1.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node2.ID.String())
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node2.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node3.ID.String())
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node3.ID.String())
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node3.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, stats, 3)
|
||||
|
||||
statsMap := make(map[uuid.UUID]int)
|
||||
for _, stat := range stats {
|
||||
statsMap[stat.ID] = stat.ActiveBackups
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, statsMap[node1.ID])
|
||||
assert.Equal(t, 2, statsMap[node2.ID])
|
||||
assert.Equal(t, 3, statsMap[node3.ID])
|
||||
}
|
||||
|
||||
func Test_GetBackupNodesStats_WhenNoStats_ReturnsEmptySlice(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stats)
|
||||
assert.Empty(t, stats)
|
||||
}
|
||||
|
||||
func Test_MultipleNodes_RegisteredAndQueriedCorrectly(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node1 := createTestBackupNode()
|
||||
node1.ThroughputMBs = 50
|
||||
node2 := createTestBackupNode()
|
||||
node2.ThroughputMBs = 100
|
||||
node3 := createTestBackupNode()
|
||||
node3.ThroughputMBs = 150
|
||||
defer cleanupTestNode(registry, node1)
|
||||
defer cleanupTestNode(registry, node2)
|
||||
defer cleanupTestNode(registry, node3)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 3)
|
||||
|
||||
nodeMap := make(map[uuid.UUID]BackupNode)
|
||||
for _, node := range nodes {
|
||||
nodeMap[node.ID] = node
|
||||
}
|
||||
|
||||
assert.Equal(t, 50, nodeMap[node1.ID].ThroughputMBs)
|
||||
assert.Equal(t, 100, nodeMap[node2.ID].ThroughputMBs)
|
||||
assert.Equal(t, 150, nodeMap[node3.ID].ThroughputMBs)
|
||||
}
|
||||
|
||||
func Test_BackupCounters_TrackedSeparatelyPerNode(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node1)
|
||||
defer cleanupTestNode(registry, node2)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node1.ID.String())
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node1.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node2.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, stats, 2)
|
||||
|
||||
statsMap := make(map[uuid.UUID]int)
|
||||
for _, stat := range stats {
|
||||
statsMap[stat.ID] = stat.ActiveBackups
|
||||
}
|
||||
|
||||
assert.Equal(t, 2, statsMap[node1.ID])
|
||||
assert.Equal(t, 1, statsMap[node2.ID])
|
||||
|
||||
err = registry.DecrementBackupsInProgress(node1.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err = registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
|
||||
statsMap = make(map[uuid.UUID]int)
|
||||
for _, stat := range stats {
|
||||
statsMap[stat.ID] = stat.ActiveBackups
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, statsMap[node1.ID])
|
||||
assert.Equal(t, 1, statsMap[node2.ID])
|
||||
}
|
||||
|
||||
func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
defer cancel()
|
||||
|
||||
invalidKey := nodeInfoKeyPrefix + uuid.New().String() + nodeInfoKeySuffix
|
||||
registry.client.Do(
|
||||
ctx,
|
||||
registry.client.B().Set().Key(invalidKey).Value("invalid json data").Build(),
|
||||
)
|
||||
defer func() {
|
||||
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
defer cleanupCancel()
|
||||
registry.client.Do(cleanupCtx, registry.client.B().Del().Key(invalidKey).Build())
|
||||
}()
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, node.ID, nodes[0].ID)
|
||||
}
|
||||
|
||||
func Test_PipelineGetKeys_HandlesEmptyKeysList(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
|
||||
keyDataMap, err := registry.pipelineGetKeys([]string{})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, keyDataMap)
|
||||
assert.Empty(t, keyDataMap)
|
||||
}
|
||||
|
||||
func Test_HearthbeatNodeInRegistry_UpdatesLastHeartbeat(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
originalHeartbeat := node.LastHeartbeat
|
||||
defer cleanupTestNode(registry, node)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.True(t, nodes[0].LastHeartbeat.After(originalHeartbeat))
|
||||
}
|
||||
|
||||
func createTestRegistry() *BackupNodesRegistry {
|
||||
return &BackupNodesRegistry{
|
||||
cache_utils.GetValkeyClient(),
|
||||
logger.GetLogger(),
|
||||
cache_utils.DefaultCacheTimeout,
|
||||
cache_utils.NewPubSubManager(),
|
||||
cache_utils.NewPubSubManager(),
|
||||
}
|
||||
}
|
||||
|
||||
func createTestBackupNode() BackupNode {
|
||||
return BackupNode{
|
||||
ID: uuid.New(),
|
||||
ThroughputMBs: 100,
|
||||
LastHeartbeat: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupTestNode(registry *BackupNodesRegistry, node BackupNode) {
|
||||
registry.UnregisterNodeFromRegistry(node)
|
||||
}
|
||||
|
||||
func Test_AssignBackupTonode_PublishesJsonMessageToChannel(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID := uuid.New()
|
||||
|
||||
err := registry.AssignBackupToNode(node.ID.String(), backupID, true)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_SubscribeNodeForBackupsAssignment_ReceivesSubmittedBackupsForMatchingNode(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID := uuid.New()
|
||||
defer registry.UnsubscribeNodeForBackupsAssignments()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {
|
||||
receivedBackupID <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case received := <-receivedBackupID:
|
||||
assert.Equal(t, backupID, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for backup message")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SubscribeNodeForBackupsAssignment_FiltersOutBackupsForDifferentNode(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
backupID := uuid.New()
|
||||
defer registry.UnsubscribeNodeForBackupsAssignments()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {
|
||||
receivedBackupID <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node2.ID.String(), backupID, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-receivedBackupID:
|
||||
t.Fatal("Should not receive backup for different node")
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SubscribeNodeForBackupsAssignment_ParsesJsonAndBackupIdCorrectly(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
defer registry.UnsubscribeNodeForBackupsAssignments()
|
||||
|
||||
receivedBackups := make(chan uuid.UUID, 2)
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {
|
||||
receivedBackups <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
received1 := <-receivedBackups
|
||||
received2 := <-receivedBackups
|
||||
|
||||
receivedIDs := []uuid.UUID{received1, received2}
|
||||
assert.Contains(t, receivedIDs, backupID1)
|
||||
assert.Contains(t, receivedIDs, backupID2)
|
||||
}
|
||||
|
||||
func Test_SubscribeNodeForBackupsAssignment_HandlesInvalidJson(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer registry.UnsubscribeNodeForBackupsAssignments()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {
|
||||
receivedBackupID <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
ctx := context.Background()
|
||||
err = registry.pubsubBackups.Publish(ctx, "backup:submit", "invalid json")
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-receivedBackupID:
|
||||
t.Fatal("Should not receive backup for invalid JSON")
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func Test_UnsubscribeNodeForBackupsAssignments_StopsReceivingMessages(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 2)
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {
|
||||
receivedBackupID <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
received := <-receivedBackupID
|
||||
assert.Equal(t, backupID1, received)
|
||||
|
||||
err = registry.UnsubscribeNodeForBackupsAssignments()
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-receivedBackupID:
|
||||
t.Fatal("Should not receive backup after unsubscribe")
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SubscribeNodeForBackupsAssignment_WhenAlreadySubscribed_ReturnsError(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer registry.UnsubscribeNodeForBackupsAssignments()
|
||||
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already subscribed")
|
||||
}
|
||||
|
||||
func Test_MultipleNodes_EachReceivesOnlyTheirBackups(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry1 := createTestRegistry()
|
||||
registry2 := createTestRegistry()
|
||||
registry3 := createTestRegistry()
|
||||
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
node3 := createTestBackupNode()
|
||||
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
backupID3 := uuid.New()
|
||||
|
||||
defer registry1.UnsubscribeNodeForBackupsAssignments()
|
||||
defer registry2.UnsubscribeNodeForBackupsAssignments()
|
||||
defer registry3.UnsubscribeNodeForBackupsAssignments()
|
||||
|
||||
receivedBackups1 := make(chan uuid.UUID, 3)
|
||||
receivedBackups2 := make(chan uuid.UUID, 3)
|
||||
receivedBackups3 := make(chan uuid.UUID, 3)
|
||||
|
||||
handler1 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups1 <- id }
|
||||
handler2 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups2 <- id }
|
||||
handler3 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups3 <- id }
|
||||
|
||||
err := registry1.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry2.SubscribeNodeForBackupsAssignment(node2.ID.String(), handler2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry3.SubscribeNodeForBackupsAssignment(node3.ID.String(), handler3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
submitRegistry := createTestRegistry()
|
||||
err = submitRegistry.AssignBackupToNode(node1.ID.String(), backupID1, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = submitRegistry.AssignBackupToNode(node2.ID.String(), backupID2, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = submitRegistry.AssignBackupToNode(node3.ID.String(), backupID3, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case received := <-receivedBackups1:
|
||||
assert.Equal(t, backupID1, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Node 1 timeout waiting for backup message")
|
||||
}
|
||||
|
||||
select {
|
||||
case received := <-receivedBackups2:
|
||||
assert.Equal(t, backupID2, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Node 2 timeout waiting for backup message")
|
||||
}
|
||||
|
||||
select {
|
||||
case received := <-receivedBackups3:
|
||||
assert.Equal(t, backupID3, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Node 3 timeout waiting for backup message")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receivedBackups1:
|
||||
t.Fatal("Node 1 should not receive additional backups")
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receivedBackups2:
|
||||
t.Fatal("Node 2 should not receive additional backups")
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receivedBackups3:
|
||||
t.Fatal("Node 3 should not receive additional backups")
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func Test_PublishBackupCompletion_PublishesMessageToChannel(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID := uuid.New()
|
||||
|
||||
err := registry.PublishBackupCompletion(node.ID.String(), backupID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_SubscribeForBackupsCompletions_ReceivesCompletedBackups(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID := uuid.New()
|
||||
defer registry.UnsubscribeForBackupsCompletions()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
receivedNodeID := make(chan string, 1)
|
||||
handler := func(nodeID string, backupID uuid.UUID) {
|
||||
receivedNodeID <- nodeID
|
||||
receivedBackupID <- backupID
|
||||
}
|
||||
|
||||
err := registry.SubscribeForBackupsCompletions(handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case receivedNode := <-receivedNodeID:
|
||||
assert.Equal(t, node.ID.String(), receivedNode)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for node ID")
|
||||
}
|
||||
|
||||
select {
|
||||
case received := <-receivedBackupID:
|
||||
assert.Equal(t, backupID, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for backup completion message")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SubscribeForBackupsCompletions_ParsesJsonCorrectly(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
defer registry.UnsubscribeForBackupsCompletions()
|
||||
|
||||
receivedBackups := make(chan uuid.UUID, 2)
|
||||
handler := func(nodeID string, backupID uuid.UUID) {
|
||||
receivedBackups <- backupID
|
||||
}
|
||||
|
||||
err := registry.SubscribeForBackupsCompletions(handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
received1 := <-receivedBackups
|
||||
received2 := <-receivedBackups
|
||||
|
||||
receivedIDs := []uuid.UUID{received1, received2}
|
||||
assert.Contains(t, receivedIDs, backupID1)
|
||||
assert.Contains(t, receivedIDs, backupID2)
|
||||
}
|
||||
|
||||
func Test_SubscribeForBackupsCompletions_HandlesInvalidJson(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
defer registry.UnsubscribeForBackupsCompletions()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
handler := func(nodeID string, backupID uuid.UUID) {
|
||||
receivedBackupID <- backupID
|
||||
}
|
||||
|
||||
err := registry.SubscribeForBackupsCompletions(handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
ctx := context.Background()
|
||||
err = registry.pubsubCompletions.Publish(ctx, "backup:completion", "invalid json")
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-receivedBackupID:
|
||||
t.Fatal("Should not receive backup for invalid JSON")
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 2)
|
||||
handler := func(nodeID string, backupID uuid.UUID) {
|
||||
receivedBackupID <- backupID
|
||||
}
|
||||
|
||||
err := registry.SubscribeForBackupsCompletions(handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
received := <-receivedBackupID
|
||||
assert.Equal(t, backupID1, received)
|
||||
|
||||
err = registry.UnsubscribeForBackupsCompletions()
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-receivedBackupID:
|
||||
t.Fatal("Should not receive backup after unsubscribe")
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SubscribeForBackupsCompletions_WhenAlreadySubscribed_ReturnsError(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
    defer registry.UnsubscribeForBackupsCompletions()

    handler := func(nodeID string, backupID uuid.UUID) {}

    err := registry.SubscribeForBackupsCompletions(handler)
    assert.NoError(t, err)

    err = registry.SubscribeForBackupsCompletions(handler)
    assert.Error(t, err)
    assert.Contains(t, err.Error(), "already subscribed")
}

func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
    cache_utils.ClearAllCache()
    registry1 := createTestRegistry()
    registry2 := createTestRegistry()
    registry3 := createTestRegistry()

    node1 := createTestBackupNode()
    node2 := createTestBackupNode()
    node3 := createTestBackupNode()

    backupID1 := uuid.New()
    backupID2 := uuid.New()
    backupID3 := uuid.New()

    defer registry1.UnsubscribeForBackupsCompletions()
    defer registry2.UnsubscribeForBackupsCompletions()
    defer registry3.UnsubscribeForBackupsCompletions()

    receivedBackups1 := make(chan uuid.UUID, 3)
    receivedBackups2 := make(chan uuid.UUID, 3)
    receivedBackups3 := make(chan uuid.UUID, 3)

    handler1 := func(nodeID string, backupID uuid.UUID) { receivedBackups1 <- backupID }
    handler2 := func(nodeID string, backupID uuid.UUID) { receivedBackups2 <- backupID }
    handler3 := func(nodeID string, backupID uuid.UUID) { receivedBackups3 <- backupID }

    err := registry1.SubscribeForBackupsCompletions(handler1)
    assert.NoError(t, err)

    err = registry2.SubscribeForBackupsCompletions(handler2)
    assert.NoError(t, err)

    err = registry3.SubscribeForBackupsCompletions(handler3)
    assert.NoError(t, err)

    time.Sleep(100 * time.Millisecond)

    publishRegistry := createTestRegistry()
    err = publishRegistry.PublishBackupCompletion(node1.ID.String(), backupID1)
    assert.NoError(t, err)

    err = publishRegistry.PublishBackupCompletion(node2.ID.String(), backupID2)
    assert.NoError(t, err)

    err = publishRegistry.PublishBackupCompletion(node3.ID.String(), backupID3)
    assert.NoError(t, err)

    receivedAll1 := []uuid.UUID{}
    receivedAll2 := []uuid.UUID{}
    receivedAll3 := []uuid.UUID{}

    for i := 0; i < 3; i++ {
        select {
        case received := <-receivedBackups1:
            receivedAll1 = append(receivedAll1, received)
        case <-time.After(2 * time.Second):
            t.Fatal("Subscriber 1 timeout waiting for completion message")
        }
    }

    for i := 0; i < 3; i++ {
        select {
        case received := <-receivedBackups2:
            receivedAll2 = append(receivedAll2, received)
        case <-time.After(2 * time.Second):
            t.Fatal("Subscriber 2 timeout waiting for completion message")
        }
    }

    for i := 0; i < 3; i++ {
        select {
        case received := <-receivedBackups3:
            receivedAll3 = append(receivedAll3, received)
        case <-time.After(2 * time.Second):
            t.Fatal("Subscriber 3 timeout waiting for completion message")
        }
    }

    assert.Contains(t, receivedAll1, backupID1)
    assert.Contains(t, receivedAll1, backupID2)
    assert.Contains(t, receivedAll1, backupID3)

    assert.Contains(t, receivedAll2, backupID1)
    assert.Contains(t, receivedAll2, backupID2)
    assert.Contains(t, receivedAll2, backupID3)

    assert.Contains(t, receivedAll3, backupID1)
    assert.Contains(t, receivedAll3, backupID2)
    assert.Contains(t, receivedAll3, backupID3)
}
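The "already subscribed" assertion above implies the registry allows only one completion subscription per instance. A minimal sketch of that guard, assuming a mutex and a boolean flag on the registry (field names are assumptions, not the actual code):

// Sketch only: per-registry guard that rejects a second subscription.
// Fields mu and subscribed are hypothetical.
func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(handler func(nodeID string, backupID uuid.UUID)) error {
    r.mu.Lock()
    defer r.mu.Unlock()
    if r.subscribed {
        return fmt.Errorf("already subscribed to backup completions")
    }
    r.subscribed = true
    // The real implementation would register the handler on the shared
    // pub/sub completion channel here.
    return nil
}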
@@ -3,10 +3,11 @@ package backuping

import (
    "context"
    "databasus-backend/internal/config"
    backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
    backups_core "databasus-backend/internal/features/backups/backups/core"
    backups_config "databasus-backend/internal/features/backups/config"
    "databasus-backend/internal/features/storages"
    task_cancellation "databasus-backend/internal/features/tasks/cancellation"
    task_registry "databasus-backend/internal/features/tasks/registry"
    "databasus-backend/internal/util/encryption"
    "databasus-backend/internal/util/period"
    "fmt"

@@ -16,12 +17,18 @@ import (
    "github.com/google/uuid"
)

const (
    schedulerStartupDelay         = 1 * time.Minute
    schedulerTickerInterval       = 1 * time.Minute
    schedulerHealthcheckThreshold = 5 * time.Minute
)

type BackupsScheduler struct {
    backupRepository    *backups_core.BackupRepository
    backupConfigService *backups_config.BackupConfigService
    storageService      *storages.StorageService
    backupCancelManager *backups_cancellation.BackupCancelManager
    nodesRegistry       *BackupNodesRegistry
    taskCancelManager   *task_cancellation.TaskCancelManager
    tasksRegistry       *task_registry.TaskNodesRegistry

    lastBackupTime time.Time
    logger         *slog.Logger

@@ -35,7 +42,7 @@ func (s *BackupsScheduler) Run(ctx context.Context) {

    if config.GetEnv().IsManyNodesMode {
        // wait for other nodes to start
        time.Sleep(1 * time.Minute)
        time.Sleep(schedulerStartupDelay)
    }

    if err := s.failBackupsInProgress(); err != nil {

@@ -43,12 +50,12 @@ func (s *BackupsScheduler) Run(ctx context.Context) {
        panic(err)
    }

    if err := s.nodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted); err != nil {
    if err := s.tasksRegistry.SubscribeForTasksCompletions(s.onBackupCompleted); err != nil {
        s.logger.Error("Failed to subscribe to backup completions", "error", err)
        panic(err)
    }
    defer func() {
        if err := s.nodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
        if err := s.tasksRegistry.UnsubscribeForTasksCompletions(); err != nil {
            s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
        }
    }()

@@ -57,7 +64,7 @@ func (s *BackupsScheduler) Run(ctx context.Context) {
        return
    }

    ticker := time.NewTicker(1 * time.Minute)
    ticker := time.NewTicker(schedulerTickerInterval)
    defer ticker.Stop()

    for {

@@ -84,7 +91,7 @@ func (s *BackupsScheduler) Run(ctx context.Context) {

func (s *BackupsScheduler) IsSchedulerRunning() bool {
    // if the last backup time is more than 5 minutes ago, return false
    return s.lastBackupTime.After(time.Now().UTC().Add(-5 * time.Minute))
    return s.lastBackupTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
}

func (s *BackupsScheduler) failBackupsInProgress() error {

@@ -93,12 +100,10 @@ func (s *BackupsScheduler) failBackupsInProgress() error {
        return err
    }

    fmt.Println("Backups in progress", len(backupsInProgress))

    for _, backup := range backupsInProgress {
        if err := s.backupCancelManager.CancelBackup(backup.ID); err != nil {
        if err := s.taskCancelManager.CancelTask(backup.ID); err != nil {
            s.logger.Error(
                "Failed to cancel backup via context manager",
                "Failed to cancel backup via task cancel manager",
                "backupId",
                backup.ID,
                "error",

@@ -175,7 +180,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
        return
    }

    if err := s.nodesRegistry.IncrementBackupsInProgress(leastBusyNodeID.String()); err != nil {
    if err := s.tasksRegistry.IncrementTasksInProgress(leastBusyNodeID.String()); err != nil {
        s.logger.Error(
            "Failed to increment backups in progress",
            "nodeId",

@@ -188,7 +193,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
        return
    }

    if err := s.nodesRegistry.AssignBackupToNode(leastBusyNodeID.String(), backup.ID, isCallNotifier); err != nil {
    if err := s.tasksRegistry.AssignTaskToNode(leastBusyNodeID.String(), backup.ID, isCallNotifier); err != nil {
        s.logger.Error(
            "Failed to submit backup",
            "nodeId",

@@ -198,7 +203,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
            "error",
            err,
        )
        if decrementErr := s.nodesRegistry.DecrementBackupsInProgress(leastBusyNodeID.String()); decrementErr != nil {
        if decrementErr := s.tasksRegistry.DecrementTasksInProgress(leastBusyNodeID.String()); decrementErr != nil {
            s.logger.Error(
                "Failed to decrement backups in progress after submit failure",
                "nodeId",

@@ -393,7 +398,7 @@ func (s *BackupsScheduler) runPendingBackups() error {
}

func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
    nodes, err := s.nodesRegistry.GetAvailableNodes()
    nodes, err := s.tasksRegistry.GetAvailableNodes()
    if err != nil {
        return nil, fmt.Errorf("failed to get available nodes: %w", err)
    }

@@ -402,27 +407,22 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
        return nil, fmt.Errorf("no nodes available")
    }

    stats, err := s.nodesRegistry.GetBackupNodesStats()
    stats, err := s.tasksRegistry.GetNodesStats()
    if err != nil {
        return nil, fmt.Errorf("failed to get backup nodes stats: %w", err)
    }

    statsMap := make(map[uuid.UUID]int)
    for _, stat := range stats {
        statsMap[stat.ID] = stat.ActiveBackups
        statsMap[stat.ID] = stat.ActiveTasks
    }

    var bestNode *BackupNode
    var bestNode *task_registry.TaskNode
    var bestScore float64 = -1

    now := time.Now().UTC()
    for i := range nodes {
        node := &nodes[i]

        if now.Sub(node.LastHeartbeat) > 2*time.Minute {
            continue
        }

        activeBackups := statsMap[node.ID]

        var score float64

@@ -458,6 +458,13 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI
        return
    }

    // Verify this task is actually a backup (registry contains multiple task types)
    _, err = s.backupRepository.FindByID(backupID)
    if err != nil {
        // Not a backup task, ignore it
        return
    }

    relation, exists := s.backupToNodeRelations[nodeID]
    if !exists {
        s.logger.Warn(

@@ -498,7 +505,7 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI
        s.backupToNodeRelations[nodeID] = relation
    }

    if err := s.nodesRegistry.DecrementBackupsInProgress(nodeIDStr); err != nil {
    if err := s.tasksRegistry.DecrementTasksInProgress(nodeIDStr); err != nil {
        s.logger.Error(
            "Failed to decrement backups in progress",
            "nodeId",

@@ -512,18 +519,14 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI
    }

func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
    nodes, err := s.nodesRegistry.GetAvailableNodes()
    nodes, err := s.tasksRegistry.GetAvailableNodes()
    if err != nil {
        return fmt.Errorf("failed to get available nodes: %w", err)
    }

    aliveNodeIDs := make(map[uuid.UUID]bool)
    now := time.Now().UTC()

    for _, node := range nodes {
        if now.Sub(node.LastHeartbeat) <= 2*time.Minute {
            aliveNodeIDs[node.ID] = true
        }
        aliveNodeIDs[node.ID] = true
    }

    for nodeID, relation := range s.backupToNodeRelations {

@@ -572,7 +575,7 @@ func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
            continue
        }

        if err := s.nodesRegistry.DecrementBackupsInProgress(nodeID.String()); err != nil {
        if err := s.tasksRegistry.DecrementTasksInProgress(nodeID.String()); err != nil {
            s.logger.Error(
                "Failed to decrement backups in progress for dead node",
                "nodeId",
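The scoring body of calculateLeastBusyNode is elided by the hunk above. One plausible rule consistent with the fields the diff does show (ThroughputMBs, ActiveTasks, bestScore starting at -1) would weight throughput against load; the formula itself is an assumption, not the repository's code:

// Illustrative score only; the actual formula is not part of this hunk.
activeTasks := statsMap[node.ID]
score := float64(node.ThroughputMBs) / float64(activeTasks+1)
if score > bestScore {
    bestScore = score
    bestNode = node
}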
@@ -7,6 +7,7 @@ import (
    "databasus-backend/internal/features/intervals"
    "databasus-backend/internal/features/notifiers"
    "databasus-backend/internal/features/storages"
    task_registry "databasus-backend/internal/features/tasks/registry"
    users_enums "databasus-backend/internal/features/users/enums"
    users_testing "databasus-backend/internal/features/users/testing"
    workspaces_testing "databasus-backend/internal/features/workspaces/testing"

@@ -42,8 +43,8 @@ func Test_RunPendingBackups_WhenLastBackupWasYesterday_CreatesNewBackup(t *testi

        databases.RemoveTestDatabase(database)
        time.Sleep(50 * time.Millisecond)
        notifiers.RemoveTestNotifier(notifier)
        storages.RemoveTestStorage(storage.ID)
        notifiers.RemoveTestNotifier(notifier)
        workspaces_testing.RemoveTestWorkspace(workspace, router)
    }()

@@ -111,8 +112,8 @@ func Test_RunPendingBackups_WhenLastBackupWasRecentlyCompleted_SkipsBackup(t *te

        databases.RemoveTestDatabase(database)
        time.Sleep(50 * time.Millisecond)
        notifiers.RemoveTestNotifier(notifier)
        storages.RemoveTestStorage(storage.ID)
        notifiers.RemoveTestNotifier(notifier)
        workspaces_testing.RemoveTestWorkspace(workspace, router)
    }()

@@ -179,8 +180,8 @@ func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesDisabled_SkipsBackup(t

        databases.RemoveTestDatabase(database)
        time.Sleep(50 * time.Millisecond)
        notifiers.RemoveTestNotifier(notifier)
        storages.RemoveTestStorage(storage.ID)
        notifiers.RemoveTestNotifier(notifier)
        workspaces_testing.RemoveTestWorkspace(workspace, router)
    }()

@@ -251,8 +252,8 @@ func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesEnabled_CreatesNewBack

        databases.RemoveTestDatabase(database)
        time.Sleep(50 * time.Millisecond)
        notifiers.RemoveTestNotifier(notifier)
        storages.RemoveTestStorage(storage.ID)
        notifiers.RemoveTestNotifier(notifier)
        workspaces_testing.RemoveTestWorkspace(workspace, router)
    }()

@@ -324,8 +325,8 @@ func Test_RunPendingBackups_WhenFailedBackupsExceedMaxRetries_SkipsBackup(t *tes

        databases.RemoveTestDatabase(database)
        time.Sleep(50 * time.Millisecond)
        notifiers.RemoveTestNotifier(notifier)
        storages.RemoveTestStorage(storage.ID)
        notifiers.RemoveTestNotifier(notifier)
        workspaces_testing.RemoveTestWorkspace(workspace, router)
    }()

@@ -396,8 +397,8 @@ func Test_RunPendingBackups_WhenBackupsDisabled_SkipsBackup(t *testing.T) {

        databases.RemoveTestDatabase(database)
        time.Sleep(50 * time.Millisecond)
        notifiers.RemoveTestNotifier(notifier)
        storages.RemoveTestStorage(storage.ID)
        notifiers.RemoveTestNotifier(notifier)
        workspaces_testing.RemoveTestWorkspace(workspace, router)
    }()

@@ -449,6 +450,8 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
    notifier := notifiers.CreateTestNotifier(workspace.ID)
    database := databases.CreateTestDatabase(workspace.ID, storage, notifier)

    var mockNodeID uuid.UUID

    defer func() {
        backups, _ := backupRepository.FindByDatabaseID(database.ID)
        for _, backup := range backups {

@@ -457,9 +460,15 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist

        databases.RemoveTestDatabase(database)
        time.Sleep(50 * time.Millisecond)
        notifiers.RemoveTestNotifier(notifier)
        storages.RemoveTestStorage(storage.ID)
        notifiers.RemoveTestNotifier(notifier)
        workspaces_testing.RemoveTestWorkspace(workspace, router)

        // Clean up mock node
        if mockNodeID != uuid.Nil {
            nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: mockNodeID})
        }
        cache_utils.ClearAllCache()
    }()

    backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)

@@ -479,7 +488,7 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
    assert.NoError(t, err)

    // Register mock node without subscribing to backups (simulates node crash after registration)
    mockNodeID := uuid.New()
    mockNodeID = uuid.New()
    err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
    assert.NoError(t, err)

@@ -493,12 +502,12 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
    assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)

    // Verify Valkey counter was incremented when backup was assigned
    stats, err := nodesRegistry.GetBackupNodesStats()
    stats, err := nodesRegistry.GetNodesStats()
    assert.NoError(t, err)
    foundStat := false
    for _, stat := range stats {
        if stat.ID == mockNodeID {
            assert.Equal(t, 1, stat.ActiveBackups)
            assert.Equal(t, 1, stat.ActiveTasks)
            foundStat = true
            break
        }

@@ -523,19 +532,117 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
    assert.Contains(t, *backups[0].FailMessage, "node unavailability")

    // Verify Valkey counter was decremented after backup failed
    stats, err = nodesRegistry.GetBackupNodesStats()
    stats, err = nodesRegistry.GetNodesStats()
    assert.NoError(t, err)
    for _, stat := range stats {
        if stat.ID == mockNodeID {
            assert.Equal(t, 0, stat.ActiveBackups)
            assert.Equal(t, 0, stat.ActiveTasks)
        }
    }

    // Node info should still exist in registry (not removed by checkDeadNodesAndFailBackups)
    node, err := GetNodeFromRegistry(mockNodeID)
    time.Sleep(200 * time.Millisecond)
}

func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {
    cache_utils.ClearAllCache()

    user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
    router := CreateTestRouter()
    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
    storage := storages.CreateTestStorage(workspace.ID)
    notifier := notifiers.CreateTestNotifier(workspace.ID)
    database := databases.CreateTestDatabase(workspace.ID, storage, notifier)

    var mockNodeID uuid.UUID

    defer func() {
        backups, _ := backupRepository.FindByDatabaseID(database.ID)
        for _, backup := range backups {
            backupRepository.DeleteByID(backup.ID)
        }

        databases.RemoveTestDatabase(database)
        time.Sleep(50 * time.Millisecond)
        storages.RemoveTestStorage(storage.ID)
        notifiers.RemoveTestNotifier(notifier)
        workspaces_testing.RemoveTestWorkspace(workspace, router)

        // Clean up mock node
        if mockNodeID != uuid.Nil {
            nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: mockNodeID})
        }
        cache_utils.ClearAllCache()
    }()

    backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
    assert.NoError(t, err)
    assert.NotNil(t, node)
    assert.Equal(t, mockNodeID, node.ID)

    timeOfDay := "04:00"
    backupConfig.BackupInterval = &intervals.Interval{
        Interval:  intervals.IntervalDaily,
        TimeOfDay: &timeOfDay,
    }
    backupConfig.IsBackupsEnabled = true
    backupConfig.StorePeriod = period.PeriodWeek
    backupConfig.Storage = storage
    backupConfig.StorageID = &storage.ID

    _, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
    assert.NoError(t, err)

    // Register mock node
    mockNodeID = uuid.New()
    err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
    assert.NoError(t, err)

    // Start a backup and assign it to the node
    GetBackupsScheduler().StartBackup(database.ID, false)
    time.Sleep(100 * time.Millisecond)

    backups, err := backupRepository.FindByDatabaseID(database.ID)
    assert.NoError(t, err)
    assert.Len(t, backups, 1)
    assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)

    // Get initial state of the registry
    initialStats, err := nodesRegistry.GetNodesStats()
    assert.NoError(t, err)
    var initialActiveTasks int
    for _, stat := range initialStats {
        if stat.ID == mockNodeID {
            initialActiveTasks = stat.ActiveTasks
            break
        }
    }
    assert.Equal(t, 1, initialActiveTasks, "Should have 1 active task")

    // Call onBackupCompleted with a random UUID (not a backup ID)
    nonBackupTaskID := uuid.New()
    GetBackupsScheduler().onBackupCompleted(mockNodeID.String(), nonBackupTaskID)

    time.Sleep(100 * time.Millisecond)

    // Verify: Active tasks counter should remain the same (not decremented)
    stats, err := nodesRegistry.GetNodesStats()
    assert.NoError(t, err)
    for _, stat := range stats {
        if stat.ID == mockNodeID {
            assert.Equal(t, initialActiveTasks, stat.ActiveTasks,
                "Active tasks should not change for non-backup task")
        }
    }

    // Verify: backup should still be in progress (not modified)
    backups, err = backupRepository.FindByDatabaseID(database.ID)
    assert.NoError(t, err)
    assert.Len(t, backups, 1)
    assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status,
        "Backup status should not change for non-backup task completion")

    // Verify: backupToNodeRelations should still contain the node
    scheduler := GetBackupsScheduler()
    _, exists := scheduler.backupToNodeRelations[mockNodeID]
    assert.True(t, exists, "Node should still be in backupToNodeRelations")

    time.Sleep(200 * time.Millisecond)
}
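The dead-node test above relies on the liveness rule visible in the scheduler diff: a node counts as alive only while its last heartbeat falls inside a two-minute window. Restated as a tiny helper (the helper name is hypothetical; the threshold is taken from the diff):

// Liveness rule used by calculateLeastBusyNode and checkDeadNodesAndFailBackups.
func isNodeAlive(lastHeartbeat, now time.Time) bool {
    return now.Sub(lastHeartbeat) <= 2*time.Minute
}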
@@ -549,6 +656,14 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
    node3ID := uuid.New()
    now := time.Now().UTC()

    defer func() {
        // Clean up all mock nodes
        nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node1ID})
        nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node2ID})
        nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node3ID})
        cache_utils.ClearAllCache()
    }()

    err := CreateMockNodeInRegistry(node1ID, 100, now)
    assert.NoError(t, err)
    err = CreateMockNodeInRegistry(node2ID, 100, now)

@@ -557,17 +672,17 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
    assert.NoError(t, err)

    for range 5 {
        err = nodesRegistry.IncrementBackupsInProgress(node1ID.String())
        err = nodesRegistry.IncrementTasksInProgress(node1ID.String())
        assert.NoError(t, err)
    }

    for range 2 {
        err = nodesRegistry.IncrementBackupsInProgress(node2ID.String())
        err = nodesRegistry.IncrementTasksInProgress(node2ID.String())
        assert.NoError(t, err)
    }

    for range 8 {
        err = nodesRegistry.IncrementBackupsInProgress(node3ID.String())
        err = nodesRegistry.IncrementTasksInProgress(node3ID.String())
        assert.NoError(t, err)
    }

@@ -584,17 +699,24 @@
    node50MBsID := uuid.New()
    now := time.Now().UTC()

    defer func() {
        // Clean up all mock nodes
        nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node100MBsID})
        nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node50MBsID})
        cache_utils.ClearAllCache()
    }()

    err := CreateMockNodeInRegistry(node100MBsID, 100, now)
    assert.NoError(t, err)
    err = CreateMockNodeInRegistry(node50MBsID, 50, now)
    assert.NoError(t, err)

    for range 10 {
        err = nodesRegistry.IncrementBackupsInProgress(node100MBsID.String())
        err = nodesRegistry.IncrementTasksInProgress(node100MBsID.String())
        assert.NoError(t, err)
    }

    err = nodesRegistry.IncrementBackupsInProgress(node50MBsID.String())
    err = nodesRegistry.IncrementTasksInProgress(node50MBsID.String())
    assert.NoError(t, err)

    leastBusyNodeID, err := GetBackupsScheduler().calculateLeastBusyNode()

@@ -622,9 +744,11 @@ func Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStat

        databases.RemoveTestDatabase(database)
        time.Sleep(50 * time.Millisecond)
        notifiers.RemoveTestNotifier(notifier)
        storages.RemoveTestStorage(storage.ID)
        notifiers.RemoveTestNotifier(notifier)
        workspaces_testing.RemoveTestWorkspace(workspace, router)

        cache_utils.ClearAllCache()
    }()

    backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)

@@ -707,3 +831,204 @@

    time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T) {
    cache_utils.ClearAllCache()

    // Start scheduler so it can handle task completions
    schedulerCancel := StartSchedulerForTest(t)
    defer schedulerCancel()

    backuperNode := CreateTestBackuperNode()
    cancel := StartBackuperNodeForTest(t, backuperNode)
    defer StopBackuperNodeForTest(t, cancel, backuperNode)

    user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
    router := CreateTestRouter()
    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
    storage := storages.CreateTestStorage(workspace.ID)
    notifier := notifiers.CreateTestNotifier(workspace.ID)
    database := databases.CreateTestDatabase(workspace.ID, storage, notifier)

    defer func() {
        backups, _ := backupRepository.FindByDatabaseID(database.ID)
        for _, backup := range backups {
            backupRepository.DeleteByID(backup.ID)
        }

        databases.RemoveTestDatabase(database)
        time.Sleep(50 * time.Millisecond)
        storages.RemoveTestStorage(storage.ID)
        notifiers.RemoveTestNotifier(notifier)
        workspaces_testing.RemoveTestWorkspace(workspace, router)
    }()

    backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
    assert.NoError(t, err)

    timeOfDay := "04:00"
    backupConfig.BackupInterval = &intervals.Interval{
        Interval:  intervals.IntervalDaily,
        TimeOfDay: &timeOfDay,
    }
    backupConfig.IsBackupsEnabled = true
    backupConfig.StorePeriod = period.PeriodWeek
    backupConfig.Storage = storage
    backupConfig.StorageID = &storage.ID

    _, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
    assert.NoError(t, err)

    // Get initial active task count
    stats, err := nodesRegistry.GetNodesStats()
    assert.NoError(t, err)
    var initialActiveTasks int
    for _, stat := range stats {
        if stat.ID == backuperNode.nodeID {
            initialActiveTasks = stat.ActiveTasks
            break
        }
    }
    t.Logf("Initial active tasks: %d", initialActiveTasks)

    // Start backup
    GetBackupsScheduler().StartBackup(database.ID, false)

    // Wait for backup to complete
    WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)

    // Verify backup was created and completed
    backups, err := backupRepository.FindByDatabaseID(database.ID)
    assert.NoError(t, err)
    assert.Len(t, backups, 1)
    assert.Equal(t, backups_core.BackupStatusCompleted, backups[0].Status)

    // Wait for active task count to decrease
    decreased := WaitForActiveTasksDecrease(
        t,
        backuperNode.nodeID,
        initialActiveTasks+1,
        10*time.Second,
    )
    assert.True(t, decreased, "Active task count should have decreased after backup completion")

    // Verify final active task count equals initial count
    finalStats, err := nodesRegistry.GetNodesStats()
    assert.NoError(t, err)
    for _, stat := range finalStats {
        if stat.ID == backuperNode.nodeID {
            t.Logf("Final active tasks: %d", stat.ActiveTasks)
            assert.Equal(t, initialActiveTasks, stat.ActiveTasks,
                "Active task count should return to initial value after backup completion")
            break
        }
    }

    time.Sleep(200 * time.Millisecond)
}

func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
    cache_utils.ClearAllCache()

    // Start scheduler so it can handle task completions
    schedulerCancel := StartSchedulerForTest(t)
    defer schedulerCancel()

    backuperNode := CreateTestBackuperNode()
    cancel := StartBackuperNodeForTest(t, backuperNode)
    defer StopBackuperNodeForTest(t, cancel, backuperNode)

    user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
    router := CreateTestRouter()
    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
    storage := storages.CreateTestStorage(workspace.ID)
    notifier := notifiers.CreateTestNotifier(workspace.ID)
    database := databases.CreateTestDatabase(workspace.ID, storage, notifier)

    defer func() {
        backups, _ := backupRepository.FindByDatabaseID(database.ID)
        for _, backup := range backups {
            backupRepository.DeleteByID(backup.ID)
        }

        databases.RemoveTestDatabase(database)
        time.Sleep(50 * time.Millisecond)
        storages.RemoveTestStorage(storage.ID)
        notifiers.RemoveTestNotifier(notifier)
        workspaces_testing.RemoveTestWorkspace(workspace, router)
    }()

    // Set wrong password to cause backup failure
    // We need to bypass service layer validation which would fail on connection test
    database.Postgresql.Password = "intentionally_wrong_password"
    dbRepo := &databases.DatabaseRepository{}
    _, err := dbRepo.Save(database)
    assert.NoError(t, err)

    backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
    assert.NoError(t, err)

    timeOfDay := "04:00"
    backupConfig.BackupInterval = &intervals.Interval{
        Interval:  intervals.IntervalDaily,
        TimeOfDay: &timeOfDay,
    }
    backupConfig.IsBackupsEnabled = true
    backupConfig.StorePeriod = period.PeriodWeek
    backupConfig.Storage = storage
    backupConfig.StorageID = &storage.ID

    _, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
    assert.NoError(t, err)

    // Get initial active task count
    stats, err := nodesRegistry.GetNodesStats()
    assert.NoError(t, err)
    var initialActiveTasks int
    for _, stat := range stats {
        if stat.ID == backuperNode.nodeID {
            initialActiveTasks = stat.ActiveTasks
            break
        }
    }
    t.Logf("Initial active tasks: %d", initialActiveTasks)

    // Start backup
    GetBackupsScheduler().StartBackup(database.ID, false)

    // Wait for backup to fail
    WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)

    // Verify backup was created and failed
    backups, err := backupRepository.FindByDatabaseID(database.ID)
    assert.NoError(t, err)
    assert.Len(t, backups, 1)
    assert.Equal(t, backups_core.BackupStatusFailed, backups[0].Status)
    assert.NotNil(t, backups[0].FailMessage)
    if backups[0].FailMessage != nil {
        t.Logf("Backup failed with message: %s", *backups[0].FailMessage)
    }

    // Wait for active task count to decrease
    decreased := WaitForActiveTasksDecrease(
        t,
        backuperNode.nodeID,
        initialActiveTasks+1,
        10*time.Second,
    )
    assert.True(t, decreased, "Active task count should have decreased after backup failure")

    // Verify final active task count equals initial count
    finalStats, err := nodesRegistry.GetNodesStats()
    assert.NoError(t, err)
    for _, stat := range finalStats {
        if stat.ID == backuperNode.nodeID {
            t.Logf("Final active tasks: %d", stat.ActiveTasks)
            assert.Equal(t, initialActiveTasks, stat.ActiveTasks,
                "Active task count should return to initial value after backup failure")
            break
        }
    }

    time.Sleep(200 * time.Millisecond)
}
@@ -12,6 +12,7 @@ import (
    "databasus-backend/internal/features/databases"
    "databasus-backend/internal/features/notifiers"
    "databasus-backend/internal/features/storages"
    task_registry "databasus-backend/internal/features/tasks/registry"
    workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
    workspaces_services "databasus-backend/internal/features/workspaces/services"
    workspaces_testing "databasus-backend/internal/features/workspaces/testing"

@@ -42,7 +43,7 @@ func CreateTestBackuperNode() *BackuperNode {
    backups_config.GetBackupConfigService(),
    storages.GetStorageService(),
    notifiers.GetNotifierService(),
    backupCancelManager,
    taskCancelManager,
    nodesRegistry,
    logger.GetLogger(),
    usecases.GetCreateBackupUsecase(),

@@ -138,6 +139,34 @@ func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.
    return nil
}

// StartSchedulerForTest starts the BackupsScheduler in a goroutine for testing.
// The scheduler subscribes to task completions and manages backup lifecycle.
// Returns a context cancel function that should be deferred to stop the scheduler.
func StartSchedulerForTest(t *testing.T) context.CancelFunc {
    ctx, cancel := context.WithCancel(context.Background())

    done := make(chan struct{})

    go func() {
        GetBackupsScheduler().Run(ctx)
        close(done)
    }()

    // Give scheduler time to subscribe to completions
    time.Sleep(100 * time.Millisecond)
    t.Log("BackupsScheduler started")

    return func() {
        cancel()
        select {
        case <-done:
            t.Log("BackupsScheduler stopped gracefully")
        case <-time.After(2 * time.Second):
            t.Log("BackupsScheduler stop timeout")
        }
    }
}

// StopBackuperNodeForTest stops the BackuperNode by canceling its context.
// It waits for the node to unregister from the registry.
func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNode *BackuperNode) {

@@ -167,7 +196,7 @@ func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNo
}

func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
    backupNode := BackupNode{
    backupNode := task_registry.TaskNode{
        ID:            nodeID,
        ThroughputMBs: throughputMBs,
        LastHeartbeat: lastHeartbeat,

@@ -181,7 +210,7 @@ func UpdateNodeHeartbeatDirectly(
    throughputMBs int,
    lastHeartbeat time.Time,
) error {
    backupNode := BackupNode{
    backupNode := task_registry.TaskNode{
        ID:            nodeID,
        ThroughputMBs: throughputMBs,
        LastHeartbeat: lastHeartbeat,

@@ -190,7 +219,7 @@
    return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
}

func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
func GetNodeFromRegistry(nodeID uuid.UUID) (*task_registry.TaskNode, error) {
    nodes, err := nodesRegistry.GetAvailableNodes()
    if err != nil {
        return nil, err

@@ -204,3 +233,48 @@ func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {

    return nil, fmt.Errorf("node not found")
}

// WaitForActiveTasksDecrease waits for the active task count to decrease below the initial count.
// It polls the registry every 500ms until the count decreases or the timeout is reached.
// Returns true if the count decreased, false if timeout was reached.
func WaitForActiveTasksDecrease(
    t *testing.T,
    nodeID uuid.UUID,
    initialCount int,
    timeout time.Duration,
) bool {
    deadline := time.Now().UTC().Add(timeout)

    for time.Now().UTC().Before(deadline) {
        stats, err := nodesRegistry.GetNodesStats()
        if err != nil {
            t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err)
            time.Sleep(500 * time.Millisecond)
            continue
        }

        for _, stat := range stats {
            if stat.ID == nodeID {
                t.Logf(
                    "WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)",
                    stat.ActiveTasks,
                    initialCount,
                )
                if stat.ActiveTasks < initialCount {
                    t.Logf(
                        "WaitForActiveTasksDecrease: active tasks decreased from %d to %d",
                        initialCount,
                        stat.ActiveTasks,
                    )
                    return true
                }
                break
            }
        }

        time.Sleep(500 * time.Millisecond)
    }

    t.Logf("WaitForActiveTasksDecrease: timeout waiting for active tasks to decrease")
    return false
}
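Taken together, these helpers give the tests a compact setup/teardown pattern. A sketch of their typical composition, mirroring the Test_StartBackup_* tests earlier in this diff (initialActiveTasks comes from an earlier GetNodesStats read, as those tests show):

// Typical usage of the helpers above.
schedulerCancel := StartSchedulerForTest(t)
defer schedulerCancel()

backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)

// ...start a backup, then poll until the node's counter drops back:
decreased := WaitForActiveTasksDecrease(t, backuperNode.nodeID, initialActiveTasks+1, 10*time.Second)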
@@ -1,75 +0,0 @@
package backups_cancellation

import (
    "context"
    cache_utils "databasus-backend/internal/util/cache"
    "log/slog"
    "sync"

    "github.com/google/uuid"
)

const backupCancelChannel = "backup:cancel"

type BackupCancelManager struct {
    mu          sync.RWMutex
    cancelFuncs map[uuid.UUID]context.CancelFunc
    pubsub      *cache_utils.PubSubManager
    logger      *slog.Logger
}

func (m *BackupCancelManager) StartSubscription() {
    ctx := context.Background()

    handler := func(message string) {
        backupID, err := uuid.Parse(message)
        if err != nil {
            m.logger.Error("Invalid backup ID in cancel message", "message", message, "error", err)
            return
        }

        m.mu.Lock()
        defer m.mu.Unlock()

        cancelFunc, exists := m.cancelFuncs[backupID]
        if exists {
            cancelFunc()
            delete(m.cancelFuncs, backupID)
            m.logger.Info("Cancelled backup via Pub/Sub", "backupID", backupID)
        }
    }

    err := m.pubsub.Subscribe(ctx, backupCancelChannel, handler)
    if err != nil {
        m.logger.Error("Failed to subscribe to backup cancel channel", "error", err)
    } else {
        m.logger.Info("Successfully subscribed to backup cancel channel")
    }
}

func (m *BackupCancelManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
    m.mu.Lock()
    defer m.mu.Unlock()
    m.cancelFuncs[backupID] = cancelFunc
    m.logger.Debug("Registered backup", "backupID", backupID)
}

func (m *BackupCancelManager) CancelBackup(backupID uuid.UUID) error {
    ctx := context.Background()

    err := m.pubsub.Publish(ctx, backupCancelChannel, backupID.String())
    if err != nil {
        m.logger.Error("Failed to publish cancel message", "backupID", backupID, "error", err)
        return err
    }

    m.logger.Info("Published backup cancel message", "backupID", backupID)
    return nil
}

func (m *BackupCancelManager) UnregisterBackup(backupID uuid.UUID) {
    m.mu.Lock()
    defer m.mu.Unlock()
    delete(m.cancelFuncs, backupID)
    m.logger.Debug("Unregistered backup", "backupID", backupID)
}
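This commit deletes the backup-specific cancel manager in favor of the task-level one used throughout the scheduler changes above (GetTaskCancelManager, RegisterTask, CancelTask). The generalized manager itself is not shown in this compare view; a plausible sketch, inferred only from those call sites, assuming it keeps the deleted type's shape with backup IDs generalized to task IDs (all field names assumed):

// Hypothetical sketch of the replacement; not the repository's actual code.
type TaskCancelManager struct {
    mu          sync.RWMutex
    cancelFuncs map[uuid.UUID]context.CancelFunc
}

func (m *TaskCancelManager) RegisterTask(taskID uuid.UUID, cancelFunc context.CancelFunc) {
    m.mu.Lock()
    defer m.mu.Unlock()
    m.cancelFuncs[taskID] = cancelFunc
}

func (m *TaskCancelManager) CancelTask(taskID uuid.UUID) error {
    // The real implementation presumably publishes the task ID on a shared
    // Pub/Sub channel, mirroring the deleted BackupCancelManager above.
    m.mu.Lock()
    defer m.mu.Unlock()
    if cancel, ok := m.cancelFuncs[taskID]; ok {
        cancel()
        delete(m.cancelFuncs, taskID)
    }
    return nil
}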
@@ -1,12 +1,15 @@
package backups

import (
    "context"
    backups_core "databasus-backend/internal/features/backups/backups/core"
    backups_download "databasus-backend/internal/features/backups/backups/download"
    "databasus-backend/internal/features/databases"
    users_middleware "databasus-backend/internal/features/users/middleware"
    "fmt"
    "io"
    "net/http"
    "time"

    "github.com/gin-gonic/gin"
    "github.com/google/uuid"

@@ -174,6 +177,7 @@ func (c *BackupController) CancelBackup(ctx *gin.Context) {
// @Success 200 {object} backups_download.GenerateDownloadTokenResponse
// @Failure 400
// @Failure 401
// @Failure 409 {object} map[string]string "Download already in progress"
// @Router /backups/{id}/download-token [post]
func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
    user, ok := users_middleware.GetUserFromContext(ctx)

@@ -190,6 +194,15 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {

    response, err := c.backupService.GenerateDownloadToken(user, id)
    if err != nil {
        if err == backups_download.ErrDownloadAlreadyInProgress {
            ctx.JSON(
                http.StatusConflict,
                gin.H{
                    "error": "Download already in progress for some of backups. Please wait until previous download completed or cancel it",
                },
            )
            return
        }
        ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
        return
    }

@@ -199,14 +212,22 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {

// GetFile
// @Summary Download a backup file
// @Description Download the backup file for the specified backup using a download token
// @Description Download the backup file for the specified backup using a download token.
// @Description
// @Description **Download Concurrency Control:**
// @Description - Only one download per user is allowed at a time
// @Description - If a download is already in progress, returns 409 Conflict
// @Description - Downloads are tracked using cache with 5-second TTL and 3-second heartbeat
// @Description - Browser cancellations automatically release the download lock
// @Description - Server crashes are handled via automatic cache expiry (5 seconds)
// @Tags backups
// @Param id path string true "Backup ID"
// @Param token query string true "Download token"
// @Success 200 {file} file
// @Failure 400
// @Failure 401
// @Failure 500
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 409 {object} map[string]string "Download already in progress"
// @Failure 500 {object} map[string]string
// @Router /backups/{id}/file [get]
func (c *BackupController) GetFile(ctx *gin.Context) {
    token := ctx.Query("token")

@@ -215,7 +236,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
        return
    }

    // Get backup ID from URL
    backupIDParam := ctx.Param("id")
    backupID, err := uuid.Parse(backupIDParam)
    if err != nil {

@@ -223,13 +243,22 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
        return
    }

    downloadToken, err := c.backupService.ValidateDownloadToken(token)
    downloadToken, rateLimiter, err := c.backupService.ValidateDownloadToken(token)
    if err != nil {
        if err == backups_download.ErrDownloadAlreadyInProgress {
            ctx.JSON(
                http.StatusConflict,
                gin.H{
                    "error": "download already in progress for this user. Please wait until previous download completed or cancel it",
                },
            )
            return
        }

        ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
        return
    }

    // Verify token is for the requested backup
    if downloadToken.BackupID != backupID {
        ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
        return

@@ -239,18 +268,28 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
        downloadToken.BackupID,
    )
    if err != nil {
        c.backupService.UnregisterDownload(downloadToken.UserID)
        c.backupService.ReleaseDownloadLock(downloadToken.UserID)
        ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
        return
    }

    rateLimitedReader := backups_download.NewRateLimitedReader(fileReader, rateLimiter)

    heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
    defer func() {
        if err := fileReader.Close(); err != nil {
        cancelHeartbeat()
        c.backupService.UnregisterDownload(downloadToken.UserID)
        c.backupService.ReleaseDownloadLock(downloadToken.UserID)
        if err := rateLimitedReader.Close(); err != nil {
            fmt.Printf("Error closing file reader: %v\n", err)
        }
    }()

    go c.startDownloadHeartbeat(heartbeatCtx, downloadToken.UserID)

    filename := c.generateBackupFilename(backup, database)

    // Set Content-Length for progress tracking
    if backup.BackupSizeMb > 0 {
        sizeBytes := int64(backup.BackupSizeMb * 1024 * 1024)
        ctx.Header("Content-Length", fmt.Sprintf("%d", sizeBytes))

@@ -262,13 +301,12 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
        fmt.Sprintf("attachment; filename=\"%s\"", filename),
    )

    _, err = io.Copy(ctx.Writer, fileReader)
    _, err = io.Copy(ctx.Writer, rateLimitedReader)
    if err != nil {
        fmt.Printf("Error streaming file: %v\n", err)
        return
    }

    // Write audit log after successful download
    c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
}

@@ -334,3 +372,17 @@ func sanitizeFilename(name string) string {

    return string(result)
}

func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uuid.UUID) {
    ticker := time.NewTicker(backups_download.GetDownloadHeartbeatInterval())
    defer ticker.Stop()

    for {
        select {
        case <-ctx.Done():
            return
        case <-ticker.C:
            c.backupService.RefreshDownloadLock(userID)
        }
    }
}
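NewRateLimitedReader is introduced by this change, but its body is not part of any hunk in this compare view. A minimal sketch of the wrapper pattern its call site implies, with Limiter as a stand-in interface rather than the repository's actual limiter type:

// Illustrative sketch only; the real NewRateLimitedReader may differ.
type Limiter interface {
    WaitN(n int) // block until n bytes are allowed at the current rate
}

type rateLimitedReader struct {
    inner   io.ReadCloser
    limiter Limiter
}

func (r *rateLimitedReader) Read(p []byte) (int, error) {
    n, err := r.inner.Read(p)
    if n > 0 {
        r.limiter.WaitN(n) // throttle after each chunk is read
    }
    return n, err
}

func (r *rateLimitedReader) Close() error { return r.inner.Close() }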
@@ -913,7 +913,7 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
    assert.NoError(t, err)

    // Register a cancellable context for the backup
    GetBackupService().backupCancelManager.RegisterBackup(backup.ID, func() {})
    GetBackupService().taskCancelManager.RegisterTask(backup.ID, func() {})

    resp := test_utils.MakePostRequest(
        t,

@@ -950,6 +950,189 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
    assert.True(t, foundCancelLog, "Cancel audit log should be created")
}

func Test_ConcurrentDownloadPrevention(t *testing.T) {
    router := createTestRouter()
    owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)

    database, backup := createTestDatabaseWithBackups(workspace, owner, router)

    var token1Response backups_download.GenerateDownloadTokenResponse
    test_utils.MakePostRequestAndUnmarshal(
        t,
        router,
        fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
        "Bearer "+owner.Token,
        nil,
        http.StatusOK,
        &token1Response,
    )

    var token2Response backups_download.GenerateDownloadTokenResponse
    test_utils.MakePostRequestAndUnmarshal(
        t,
        router,
        fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
        "Bearer "+owner.Token,
        nil,
        http.StatusOK,
        &token2Response,
    )

    downloadInProgress := make(chan bool, 1)
    downloadComplete := make(chan bool, 1)

    go func() {
        test_utils.MakeGetRequest(
            t,
            router,
            fmt.Sprintf(
                "/api/v1/backups/%s/file?token=%s",
                backup.ID.String(),
                token1Response.Token,
            ),
            "",
            http.StatusOK,
        )
        downloadComplete <- true
    }()

    time.Sleep(50 * time.Millisecond)

    service := GetBackupService()
    if !service.IsDownloadInProgress(owner.UserID) {
        t.Log("Warning: First download completed before we could test concurrency")
        <-downloadComplete
        return
    }

    downloadInProgress <- true

    resp := test_utils.MakeGetRequest(
        t,
        router,
        fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token2Response.Token),
        "",
        http.StatusConflict,
    )

    var errorResponse map[string]string
    err := json.Unmarshal(resp.Body, &errorResponse)
    assert.NoError(t, err)
    assert.Contains(t, errorResponse["error"], "download already in progress")

    <-downloadComplete
    <-downloadInProgress

    time.Sleep(100 * time.Millisecond)

    var token3Response backups_download.GenerateDownloadTokenResponse
    test_utils.MakePostRequestAndUnmarshal(
        t,
        router,
        fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
        "Bearer "+owner.Token,
        nil,
        http.StatusOK,
        &token3Response,
    )

    test_utils.MakeGetRequest(
        t,
        router,
        fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token3Response.Token),
        "",
        http.StatusOK,
    )

    t.Log("Database:", database.Name)
    t.Log(
        "Successfully prevented concurrent downloads and allowed subsequent downloads after completion",
    )
}
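The lock this test exercises is described earlier in the diff as a cache entry with a 5-second TTL refreshed by a 3-second heartbeat, so a crashed server releases it automatically. A minimal sketch of that acquire/refresh/release shape, assuming a cache with set-if-not-exists and TTL support (the Cache interface and method names are hypothetical):

// Hypothetical per-user download lock with TTL + heartbeat refresh.
type Cache interface {
    SetIfNotExists(key, value string, ttl time.Duration) bool
    Expire(key string, ttl time.Duration)
    Delete(key string)
}

const downloadLockTTL = 5 * time.Second

func acquireDownloadLock(cache Cache, userID uuid.UUID) bool {
    // SetNX-style call: succeeds only if no lock exists; expires on its own
    // if the heartbeat stops refreshing it.
    return cache.SetIfNotExists("download:lock:"+userID.String(), "1", downloadLockTTL)
}

func refreshDownloadLock(cache Cache, userID uuid.UUID) {
    cache.Expire("download:lock:"+userID.String(), downloadLockTTL)
}

func releaseDownloadLock(cache Cache, userID uuid.UUID) {
    cache.Delete("download:lock:" + userID.String())
}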
func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
    router := createTestRouter()
    owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)

    database, backup := createTestDatabaseWithBackups(workspace, owner, router)

    var token1Response backups_download.GenerateDownloadTokenResponse
    test_utils.MakePostRequestAndUnmarshal(
        t,
        router,
        fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
        "Bearer "+owner.Token,
        nil,
        http.StatusOK,
        &token1Response,
    )

    downloadComplete := make(chan bool, 1)

    go func() {
        test_utils.MakeGetRequest(
            t,
            router,
            fmt.Sprintf(
                "/api/v1/backups/%s/file?token=%s",
                backup.ID.String(),
                token1Response.Token,
            ),
            "",
            http.StatusOK,
        )
        downloadComplete <- true
    }()

    time.Sleep(50 * time.Millisecond)

    service := GetBackupService()
    if !service.IsDownloadInProgress(owner.UserID) {
        t.Log("Warning: First download completed before we could test token generation blocking")
        <-downloadComplete
        return
    }

    resp := test_utils.MakePostRequest(
        t,
        router,
        fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
        "Bearer "+owner.Token,
        nil,
        http.StatusConflict,
    )

    var errorResponse map[string]string
    err := json.Unmarshal(resp.Body, &errorResponse)
    assert.NoError(t, err)
    assert.Contains(t, errorResponse["error"], "download already in progress")

    <-downloadComplete

    time.Sleep(100 * time.Millisecond)

    var token2Response backups_download.GenerateDownloadTokenResponse
    test_utils.MakePostRequestAndUnmarshal(
        t,
        router,
        fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
        "Bearer "+owner.Token,
        nil,
        http.StatusOK,
        &token2Response,
    )

    assert.NotEmpty(t, token2Response.Token)
    assert.NotEqual(t, token1Response.Token, token2Response.Token)

    t.Log("Database:", database.Name)
    t.Log(
        "Successfully blocked token generation during download and allowed generation after completion",
    )
}

func createTestRouter() *gin.Engine {
    return CreateTestRouter()
}
@@ -1131,3 +1314,267 @@ func createExpiredDownloadToken(backupID, userID uuid.UUID) string {

    return token
}

func Test_BandwidthThrottling_SingleDownload_Uses75Percent(t *testing.T) {
    router := createTestRouter()
    owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)

    _, backup := createTestDatabaseWithBackups(workspace, owner, router)

    bandwidthManager := backups_download.GetBandwidthManager()
    initialCount := bandwidthManager.GetActiveDownloadCount()

    var tokenResponse backups_download.GenerateDownloadTokenResponse
    test_utils.MakePostRequestAndUnmarshal(
        t,
        router,
        fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
        "Bearer "+owner.Token,
        nil,
        http.StatusOK,
        &tokenResponse,
    )

    downloadStarted := make(chan bool, 1)
    downloadComplete := make(chan bool, 1)

    go func() {
        test_utils.MakeGetRequest(
            t,
            router,
            fmt.Sprintf(
                "/api/v1/backups/%s/file?token=%s",
                backup.ID.String(),
                tokenResponse.Token,
            ),
            "",
            http.StatusOK,
        )
        downloadComplete <- true
    }()

    time.Sleep(50 * time.Millisecond)

    activeCount := bandwidthManager.GetActiveDownloadCount()
    if activeCount > initialCount {
        downloadStarted <- true
        assert.Equal(t, initialCount+1, activeCount, "Should have one active download")
    }

    <-downloadComplete
    if len(downloadStarted) > 0 {
        <-downloadStarted
    }

    time.Sleep(50 * time.Millisecond)
    finalCount := bandwidthManager.GetActiveDownloadCount()
    assert.Equal(t, initialCount, finalCount, "Download should be unregistered after completion")
}

func Test_BandwidthThrottling_MultipleDownloads_ShareBandwidth(t *testing.T) {
    router := createTestRouter()
    owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
    owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
    owner3 := users_testing.CreateTestUser(users_enums.UserRoleMember)

    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner1, router)
    workspaces_testing.AddMemberToWorkspace(
        workspace,
        owner2,
        users_enums.WorkspaceRoleMember,
        owner1.Token,
        router,
    )
    workspaces_testing.AddMemberToWorkspace(
        workspace,
        owner3,
        users_enums.WorkspaceRoleMember,
        owner1.Token,
        router,
    )

    database := createTestDatabase("Test Database", workspace.ID, owner1.Token, router)
    storage := createTestStorage(workspace.ID)

    configService := backups_config.GetBackupConfigService()
    config, err := configService.GetBackupConfigByDbId(database.ID)
    assert.NoError(t, err)

    config.IsBackupsEnabled = true
    config.StorageID = &storage.ID
    config.Storage = storage
    _, err = configService.SaveBackupConfig(config)
    assert.NoError(t, err)

    backup1 := createTestBackup(database, owner1)
    backup2 := createTestBackup(database, owner2)
    backup3 := createTestBackup(database, owner3)

    var token1, token2, token3 backups_download.GenerateDownloadTokenResponse
    test_utils.MakePostRequestAndUnmarshal(
        t,
        router,
        fmt.Sprintf("/api/v1/backups/%s/download-token", backup1.ID.String()),
        "Bearer "+owner1.Token,
        nil,
        http.StatusOK,
        &token1,
    )
    test_utils.MakePostRequestAndUnmarshal(
        t,
        router,
        fmt.Sprintf("/api/v1/backups/%s/download-token", backup2.ID.String()),
        "Bearer "+owner2.Token,
        nil,
        http.StatusOK,
        &token2,
    )
    test_utils.MakePostRequestAndUnmarshal(
        t,
        router,
        fmt.Sprintf("/api/v1/backups/%s/download-token", backup3.ID.String()),
        "Bearer "+owner3.Token,
        nil,
        http.StatusOK,
        &token3,
    )

    bandwidthManager := backups_download.GetBandwidthManager()
    initialCount := bandwidthManager.GetActiveDownloadCount()

    complete1 := make(chan bool, 1)
    complete2 := make(chan bool, 1)
    complete3 := make(chan bool, 1)

    go func() {
        test_utils.MakeGetRequest(
            t,
            router,
            fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup1.ID.String(), token1.Token),
            "",
            http.StatusOK,
        )
        complete1 <- true
    }()

    go func() {
        test_utils.MakeGetRequest(
            t,
            router,
            fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup2.ID.String(), token2.Token),
            "",
            http.StatusOK,
        )
        complete2 <- true
    }()

    go func() {
        test_utils.MakeGetRequest(
            t,
            router,
            fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup3.ID.String(), token3.Token),
            "",
            http.StatusOK,
        )
        complete3 <- true
    }()

    time.Sleep(100 * time.Millisecond)

    <-complete1
    <-complete2
    <-complete3

    time.Sleep(100 * time.Millisecond)
    finalCount := bandwidthManager.GetActiveDownloadCount()
    assert.Equal(t, initialCount, finalCount, "All downloads should be unregistered")
}

func Test_BandwidthThrottling_DynamicAdjustment(t *testing.T) {
    router := createTestRouter()
    owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
    owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)

    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner1, router)
    workspaces_testing.AddMemberToWorkspace(
        workspace,
        owner2,
        users_enums.WorkspaceRoleMember,
        owner1.Token,
        router,
    )

    database := createTestDatabase("Test Database", workspace.ID, owner1.Token, router)
    storage := createTestStorage(workspace.ID)

    configService := backups_config.GetBackupConfigService()
    config, err := configService.GetBackupConfigByDbId(database.ID)
    assert.NoError(t, err)

    config.IsBackupsEnabled = true
    config.StorageID = &storage.ID
    config.Storage = storage
    _, err = configService.SaveBackupConfig(config)
    assert.NoError(t, err)

    backup1 := createTestBackup(database, owner1)
    backup2 := createTestBackup(database, owner2)

    var token1, token2 backups_download.GenerateDownloadTokenResponse
    test_utils.MakePostRequestAndUnmarshal(
        t,
        router,
        fmt.Sprintf("/api/v1/backups/%s/download-token", backup1.ID.String()),
        "Bearer "+owner1.Token,
        nil,
        http.StatusOK,
        &token1,
    )
    test_utils.MakePostRequestAndUnmarshal(
        t,
        router,
        fmt.Sprintf("/api/v1/backups/%s/download-token", backup2.ID.String()),
        "Bearer "+owner2.Token,
        nil,
        http.StatusOK,
        &token2,
    )

    bandwidthManager := backups_download.GetBandwidthManager()
    initialCount := bandwidthManager.GetActiveDownloadCount()

    complete1 := make(chan bool, 1)
    complete2 := make(chan bool, 1)

    go func() {
        test_utils.MakeGetRequest(
            t,
            router,
            fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup1.ID.String(), token1.Token),
            "",
            http.StatusOK,
        )
        complete1 <- true
    }()

    time.Sleep(50 * time.Millisecond)

    go func() {
        test_utils.MakeGetRequest(
            t,
            router,
            fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup2.ID.String(), token2.Token),
            "",
            http.StatusOK,
        )
        complete2 <- true
    }()

    <-complete1
    <-complete2

    time.Sleep(100 * time.Millisecond)
    finalCount := bandwidthManager.GetActiveDownloadCount()
    assert.Equal(t, initialCount, finalCount, "All downloads completed and unregistered")
}
@@ -3,7 +3,6 @@ package backups
|
||||
import (
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
@@ -12,6 +11,7 @@ import (
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
|
||||
var backupRepository = &backups_core.BackupRepository{}
|
||||
|
||||
var backupCancelManager = backups_cancellation.GetBackupCancelManager()
|
||||
var taskCancelManager = task_cancellation.GetTaskCancelManager()
|
||||
|
||||
var backupService = &BackupService{
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
@@ -35,7 +35,7 @@ var backupService = &BackupService{
|
||||
backupRemoveListeners: []backups_core.BackupRemoveListener{},
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
auditLogService: audit_logs.GetAuditLogService(),
|
||||
backupCancelManager: backupCancelManager,
|
||||
taskCancelManager: taskCancelManager,
|
||||
downloadTokenService: backups_download.GetDownloadTokenService(),
|
||||
backupSchedulerService: backuping.GetBackupsScheduler(),
|
||||
}
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BandwidthManager struct {
|
||||
mu sync.RWMutex
|
||||
activeDownloads map[uuid.UUID]*activeDownload
|
||||
maxTotalBytesPerSecond int64
|
||||
bytesPerSecondPerDownload int64
|
||||
}
|
||||
|
||||
type activeDownload struct {
|
||||
userID uuid.UUID
|
||||
rateLimiter *RateLimiter
|
||||
}
|
||||
|
||||
func NewBandwidthManager(throughputMBs int) *BandwidthManager {
|
||||
// Use 75% of total throughput
|
||||
maxBytes := int64(throughputMBs) * 1024 * 1024 * 75 / 100
|
||||
|
||||
return &BandwidthManager{
|
||||
activeDownloads: make(map[uuid.UUID]*activeDownload),
|
||||
maxTotalBytesPerSecond: maxBytes,
|
||||
bytesPerSecondPerDownload: maxBytes,
|
||||
}
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) RegisterDownload(userID uuid.UUID) (*RateLimiter, error) {
|
||||
bm.mu.Lock()
|
||||
defer bm.mu.Unlock()
|
||||
|
||||
if _, exists := bm.activeDownloads[userID]; exists {
|
||||
return nil, fmt.Errorf("download already registered for user %s", userID)
|
||||
}
|
||||
|
||||
rateLimiter := NewRateLimiter(bm.bytesPerSecondPerDownload)
|
||||
|
||||
bm.activeDownloads[userID] = &activeDownload{
|
||||
userID: userID,
|
||||
rateLimiter: rateLimiter,
|
||||
}
|
||||
|
||||
bm.recalculateRates()
|
||||
|
||||
return rateLimiter, nil
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) UnregisterDownload(userID uuid.UUID) {
|
||||
bm.mu.Lock()
|
||||
defer bm.mu.Unlock()
|
||||
|
||||
delete(bm.activeDownloads, userID)
|
||||
bm.recalculateRates()
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) GetActiveDownloadCount() int {
|
||||
bm.mu.RLock()
|
||||
defer bm.mu.RUnlock()
|
||||
return len(bm.activeDownloads)
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) recalculateRates() {
|
||||
activeCount := len(bm.activeDownloads)
|
||||
|
||||
if activeCount == 0 {
|
||||
bm.bytesPerSecondPerDownload = bm.maxTotalBytesPerSecond
|
||||
return
|
||||
}
|
||||
|
||||
newRate := bm.maxTotalBytesPerSecond / int64(activeCount)
|
||||
bm.bytesPerSecondPerDownload = newRate
|
||||
|
||||
for _, download := range bm.activeDownloads {
|
||||
download.rateLimiter.UpdateRate(newRate)
|
||||
}
|
||||
}
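
For orientation (not part of the diff): a minimal in-package sketch of how the 75% cap is split and rebalanced. The example function name is hypothetical; it relies only on the methods defined above.

// exampleBandwidthSharing is a hypothetical illustration, not committed code.
func exampleBandwidthSharing() {
    // 75% of 100 MB/s gives a maxTotalBytesPerSecond of 75 MB/s.
    manager := NewBandwidthManager(100)

    userA, userB := uuid.New(), uuid.New()

    limiterA, _ := manager.RegisterDownload(userA) // one active download: full 75 MB/s
    limiterB, _ := manager.RegisterDownload(userB) // two active downloads: 37.5 MB/s each

    // recalculateRates has already updated both limiters in place.
    _, _ = limiterA, limiterB

    // Unregistering rebalances: limiterA goes back to the full 75 MB/s.
    manager.UnregisterDownload(userB)
}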

@@ -0,0 +1,150 @@
package backups_download

import (
    "testing"
    "time"

    "github.com/google/uuid"
    "github.com/stretchr/testify/assert"
)

func Test_BandwidthManager_RegisterSingleDownload(t *testing.T) {
    throughputMBs := 100
    manager := NewBandwidthManager(throughputMBs)

    expectedBytesPerSec := int64(100 * 1024 * 1024 * 75 / 100)
    assert.Equal(t, expectedBytesPerSec, manager.maxTotalBytesPerSecond)
    assert.Equal(t, expectedBytesPerSec, manager.bytesPerSecondPerDownload)

    userID := uuid.New()
    rateLimiter, err := manager.RegisterDownload(userID)
    assert.NoError(t, err)
    assert.NotNil(t, rateLimiter)

    assert.Equal(t, 1, manager.GetActiveDownloadCount())
    assert.Equal(t, expectedBytesPerSec, manager.bytesPerSecondPerDownload)
    assert.Equal(t, expectedBytesPerSec, rateLimiter.bytesPerSecond)
}

func Test_BandwidthManager_RegisterMultipleDownloads_BandwidthShared(t *testing.T) {
    throughputMBs := 100
    manager := NewBandwidthManager(throughputMBs)

    maxBytes := int64(100 * 1024 * 1024 * 75 / 100)

    user1 := uuid.New()
    rateLimiter1, err := manager.RegisterDownload(user1)
    assert.NoError(t, err)
    assert.Equal(t, maxBytes, rateLimiter1.bytesPerSecond)

    user2 := uuid.New()
    rateLimiter2, err := manager.RegisterDownload(user2)
    assert.NoError(t, err)

    expectedPerDownload := maxBytes / 2
    assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
    assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
    assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)

    user3 := uuid.New()
    rateLimiter3, err := manager.RegisterDownload(user3)
    assert.NoError(t, err)

    expectedPerDownload = maxBytes / 3
    assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
    assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
    assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
    assert.Equal(t, expectedPerDownload, rateLimiter3.bytesPerSecond)
    assert.Equal(t, 3, manager.GetActiveDownloadCount())
}

func Test_BandwidthManager_UnregisterDownload_BandwidthRebalanced(t *testing.T) {
    throughputMBs := 100
    manager := NewBandwidthManager(throughputMBs)

    maxBytes := int64(100 * 1024 * 1024 * 75 / 100)

    user1 := uuid.New()
    rateLimiter1, _ := manager.RegisterDownload(user1)

    user2 := uuid.New()
    _, _ = manager.RegisterDownload(user2)

    user3 := uuid.New()
    rateLimiter3, _ := manager.RegisterDownload(user3)

    assert.Equal(t, 3, manager.GetActiveDownloadCount())
    expectedPerDownload := maxBytes / 3
    assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)

    manager.UnregisterDownload(user2)

    assert.Equal(t, 2, manager.GetActiveDownloadCount())
    expectedPerDownload = maxBytes / 2
    assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
    assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
    assert.Equal(t, expectedPerDownload, rateLimiter3.bytesPerSecond)

    manager.UnregisterDownload(user1)

    assert.Equal(t, 1, manager.GetActiveDownloadCount())
    assert.Equal(t, maxBytes, manager.bytesPerSecondPerDownload)
    assert.Equal(t, maxBytes, rateLimiter3.bytesPerSecond)

    manager.UnregisterDownload(user3)
    assert.Equal(t, 0, manager.GetActiveDownloadCount())
    assert.Equal(t, maxBytes, manager.bytesPerSecondPerDownload)
}

func Test_BandwidthManager_RegisterDuplicateUser_ReturnsError(t *testing.T) {
    manager := NewBandwidthManager(100)

    userID := uuid.New()
    _, err := manager.RegisterDownload(userID)
    assert.NoError(t, err)

    _, err = manager.RegisterDownload(userID)
    assert.Error(t, err)
    assert.Contains(t, err.Error(), "download already registered")
}

func Test_RateLimiter_TokenBucketBasic(t *testing.T) {
    bytesPerSec := int64(1024 * 1024)
    limiter := NewRateLimiter(bytesPerSec)

    assert.Equal(t, bytesPerSec, limiter.bytesPerSecond)
    assert.Equal(t, bytesPerSec*2, limiter.bucketSize)

    start := time.Now()
    limiter.Wait(512 * 1024)
    elapsed := time.Since(start)

    assert.Less(t, elapsed, 100*time.Millisecond)
}

func Test_RateLimiter_UpdateRate(t *testing.T) {
    limiter := NewRateLimiter(1024 * 1024)

    assert.Equal(t, int64(1024*1024), limiter.bytesPerSecond)

    newRate := int64(2 * 1024 * 1024)
    limiter.UpdateRate(newRate)

    assert.Equal(t, newRate, limiter.bytesPerSecond)
    assert.Equal(t, newRate*2, limiter.bucketSize)
}

func Test_RateLimiter_ThrottlesCorrectly(t *testing.T) {
    bytesPerSec := int64(1024 * 1024)
    limiter := NewRateLimiter(bytesPerSec)

    limiter.availableTokens = 0

    start := time.Now()
    limiter.Wait(bytesPerSec / 2)
    elapsed := time.Since(start)

    assert.GreaterOrEqual(t, elapsed, 400*time.Millisecond)
    assert.LessOrEqual(t, elapsed, 700*time.Millisecond)
}
@@ -1,19 +1,38 @@
package backups_download

import (
    "databasus-backend/internal/config"
    cache_utils "databasus-backend/internal/util/cache"
    "databasus-backend/internal/util/logger"
)

var downloadTokenRepository = &DownloadTokenRepository{}

var downloadTokenService = &DownloadTokenService{
    downloadTokenRepository,
    logger.GetLogger(),
}
var downloadTracker = NewDownloadTracker(cache_utils.GetValkeyClient())

var downloadTokenBackgroundService = &DownloadTokenBackgroundService{
    downloadTokenService,
    logger.GetLogger(),
var bandwidthManager *BandwidthManager
var downloadTokenService *DownloadTokenService
var downloadTokenBackgroundService *DownloadTokenBackgroundService

func init() {
    env := config.GetEnv()
    throughputMBs := env.NodeNetworkThroughputMBs
    if throughputMBs == 0 {
        throughputMBs = 125
    }
    bandwidthManager = NewBandwidthManager(throughputMBs)

    downloadTokenService = &DownloadTokenService{
        downloadTokenRepository,
        logger.GetLogger(),
        downloadTracker,
        bandwidthManager,
    }

    downloadTokenBackgroundService = &DownloadTokenBackgroundService{
        downloadTokenService,
        logger.GetLogger(),
    }
}

func GetDownloadTokenService() *DownloadTokenService {
@@ -23,3 +42,7 @@ func GetDownloadTokenService() *DownloadTokenService {
func GetDownloadTokenBackgroundService() *DownloadTokenBackgroundService {
    return downloadTokenBackgroundService
}

func GetBandwidthManager() *BandwidthManager {
    return bandwidthManager
}
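
A note on the fallback: the NodeNetworkThroughputMBs default of 125 MB/s corresponds to a saturated 1 Gbps link, so with the 75% cap the default total download budget works out to roughly 93.75 MB/s.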

@@ -0,0 +1,101 @@
package backups_download

import (
    "io"
    "sync"
    "time"
)

type RateLimiter struct {
    mu              sync.Mutex
    bytesPerSecond  int64
    bucketSize      int64
    availableTokens float64
    lastRefill      time.Time
}

func NewRateLimiter(bytesPerSecond int64) *RateLimiter {
    if bytesPerSecond <= 0 {
        bytesPerSecond = 1024 * 1024 * 100
    }

    return &RateLimiter{
        bytesPerSecond:  bytesPerSecond,
        bucketSize:      bytesPerSecond * 2,
        availableTokens: float64(bytesPerSecond * 2),
        lastRefill:      time.Now().UTC(),
    }
}

func (rl *RateLimiter) UpdateRate(bytesPerSecond int64) {
    rl.mu.Lock()
    defer rl.mu.Unlock()

    if bytesPerSecond <= 0 {
        bytesPerSecond = 1024 * 1024 * 100
    }

    rl.bytesPerSecond = bytesPerSecond
    rl.bucketSize = bytesPerSecond * 2

    if rl.availableTokens > float64(rl.bucketSize) {
        rl.availableTokens = float64(rl.bucketSize)
    }
}

func (rl *RateLimiter) Wait(bytes int64) {
    rl.mu.Lock()
    defer rl.mu.Unlock()

    for {
        now := time.Now().UTC()
        elapsed := now.Sub(rl.lastRefill).Seconds()

        tokensToAdd := elapsed * float64(rl.bytesPerSecond)
        rl.availableTokens += tokensToAdd
        if rl.availableTokens > float64(rl.bucketSize) {
            rl.availableTokens = float64(rl.bucketSize)
        }
        rl.lastRefill = now

        if rl.availableTokens >= float64(bytes) {
            rl.availableTokens -= float64(bytes)
            return
        }

        tokensNeeded := float64(bytes) - rl.availableTokens
        waitTime := time.Duration(tokensNeeded/float64(rl.bytesPerSecond)*1000) * time.Millisecond

        if waitTime < time.Millisecond {
            waitTime = time.Millisecond
        }

        rl.mu.Unlock()
        time.Sleep(waitTime)
        rl.mu.Lock()
    }
}

type RateLimitedReader struct {
    reader      io.ReadCloser
    rateLimiter *RateLimiter
}

func NewRateLimitedReader(reader io.ReadCloser, limiter *RateLimiter) *RateLimitedReader {
    return &RateLimitedReader{
        reader:      reader,
        rateLimiter: limiter,
    }
}

func (r *RateLimitedReader) Read(p []byte) (n int, err error) {
    n, err = r.reader.Read(p)
    if n > 0 {
        r.rateLimiter.Wait(int64(n))
    }
    return n, err
}

func (r *RateLimitedReader) Close() error {
    return r.reader.Close()
}
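
A minimal in-package usage sketch (hypothetical helper name, not part of the diff): wrapping an arbitrary stream so reads are throttled by the token bucket above.

// exampleThrottledCopy is a hypothetical illustration of RateLimitedReader.
func exampleThrottledCopy(src io.ReadCloser, dst io.Writer) error {
    limiter := NewRateLimiter(10 * 1024 * 1024) // 10 MB/s

    throttled := NewRateLimitedReader(src, limiter)
    defer throttled.Close()

    // Each Read pulls bytes from the underlying reader first, then blocks in
    // limiter.Wait until the bucket has refilled enough to cover them.
    _, err := io.Copy(dst, throttled)
    return err
}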

@@ -9,11 +9,17 @@ import (
)

type DownloadTokenService struct {
    repository *DownloadTokenRepository
    logger     *slog.Logger
    repository       *DownloadTokenRepository
    logger           *slog.Logger
    downloadTracker  *DownloadTracker
    bandwidthManager *BandwidthManager
}

func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, error) {
    if s.downloadTracker.IsDownloadInProgress(userID) {
        return "", ErrDownloadAlreadyInProgress
    }

    token := GenerateSecureToken()

    downloadToken := &DownloadToken{
@@ -32,22 +38,34 @@ func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, err
    return token, nil
}

func (s *DownloadTokenService) ValidateAndConsume(token string) (*DownloadToken, error) {
func (s *DownloadTokenService) ValidateAndConsume(
    token string,
) (*DownloadToken, *RateLimiter, error) {
    dt, err := s.repository.FindByToken(token)
    if err != nil {
        return nil, err
        return nil, nil, err
    }

    if dt == nil {
        return nil, errors.New("invalid token")
        return nil, nil, errors.New("invalid token")
    }

    if dt.Used {
        return nil, errors.New("token already used")
        return nil, nil, errors.New("token already used")
    }

    if time.Now().UTC().After(dt.ExpiresAt) {
        return nil, errors.New("token expired")
        return nil, nil, errors.New("token expired")
    }

    if err := s.downloadTracker.AcquireDownloadLock(dt.UserID); err != nil {
        return nil, nil, err
    }

    rateLimiter, err := s.bandwidthManager.RegisterDownload(dt.UserID)
    if err != nil {
        s.downloadTracker.ReleaseDownloadLock(dt.UserID)
        return nil, nil, err
    }

    dt.Used = true
@@ -55,8 +73,26 @@ func (s *DownloadTokenService) ValidateAndConsume(token string) (*DownloadToken,
        s.logger.Error("Failed to mark token as used", "error", err)
    }

    s.logger.Info("Token validated and consumed", "backupId", dt.BackupID)
    return dt, nil
    s.logger.Info("Token validated and consumed", "backupId", dt.BackupID, "userId", dt.UserID)
    return dt, rateLimiter, nil
}

func (s *DownloadTokenService) RefreshDownloadLock(userID uuid.UUID) {
    s.downloadTracker.RefreshDownloadLock(userID)
}

func (s *DownloadTokenService) ReleaseDownloadLock(userID uuid.UUID) {
    s.downloadTracker.ReleaseDownloadLock(userID)
    s.logger.Info("Released download lock", "userId", userID)
}

func (s *DownloadTokenService) IsDownloadInProgress(userID uuid.UUID) bool {
    return s.downloadTracker.IsDownloadInProgress(userID)
}

func (s *DownloadTokenService) UnregisterDownload(userID uuid.UUID) {
    s.bandwidthManager.UnregisterDownload(userID)
    s.logger.Info("Unregistered from bandwidth manager", "userId", userID)
}

func (s *DownloadTokenService) CleanExpiredTokens() error {

@@ -0,0 +1,66 @@
package backups_download

import (
    cache_utils "databasus-backend/internal/util/cache"
    "errors"
    "time"

    "github.com/google/uuid"
    "github.com/valkey-io/valkey-go"
)

const (
    downloadLockPrefix     = "backup_download_lock:"
    downloadLockTTL        = 5 * time.Second
    downloadLockValue      = "1"
    downloadHeartbeatDelay = 3 * time.Second
)

var (
    ErrDownloadAlreadyInProgress = errors.New("download already in progress for this user")
)

type DownloadTracker struct {
    cache *cache_utils.CacheUtil[string]
}

func NewDownloadTracker(client valkey.Client) *DownloadTracker {
    return &DownloadTracker{
        cache: cache_utils.NewCacheUtil[string](client, downloadLockPrefix),
    }
}

func (t *DownloadTracker) AcquireDownloadLock(userID uuid.UUID) error {
    key := userID.String()

    existingLock := t.cache.Get(key)
    if existingLock != nil {
        return ErrDownloadAlreadyInProgress
    }

    value := downloadLockValue
    t.cache.Set(key, &value)

    return nil
}

func (t *DownloadTracker) RefreshDownloadLock(userID uuid.UUID) {
    key := userID.String()
    value := downloadLockValue
    t.cache.Set(key, &value)
}

func (t *DownloadTracker) ReleaseDownloadLock(userID uuid.UUID) {
    key := userID.String()
    t.cache.Invalidate(key)
}

func (t *DownloadTracker) IsDownloadInProgress(userID uuid.UUID) bool {
    key := userID.String()
    existingLock := t.cache.Get(key)
    return existingLock != nil
}

func GetDownloadHeartbeatInterval() time.Duration {
    return downloadHeartbeatDelay
}
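
A sketch of the intended lock lifecycle (hypothetical function name; this assumes the cache entry is stored with the 5-second downloadLockTTL, so a long-running download has to refresh on the 3-second heartbeat cadence or the lock silently expires):

// exampleDownloadLockLifecycle is a hypothetical illustration, not committed code.
func exampleDownloadLockLifecycle(tracker *DownloadTracker, userID uuid.UUID, done <-chan struct{}) error {
    if err := tracker.AcquireDownloadLock(userID); err != nil {
        return err // ErrDownloadAlreadyInProgress
    }
    defer tracker.ReleaseDownloadLock(userID)

    ticker := time.NewTicker(GetDownloadHeartbeatInterval())
    defer ticker.Stop()

    for {
        select {
        case <-done:
            return nil
        case <-ticker.C:
            // Keep the TTL-bound lock alive while the download is still running.
            tracker.RefreshDownloadLock(userID)
        }
    }
}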

@@ -9,7 +9,6 @@ import (

    audit_logs "databasus-backend/internal/features/audit_logs"
    "databasus-backend/internal/features/backups/backups/backuping"
    backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
    backups_core "databasus-backend/internal/features/backups/backups/core"
    backups_download "databasus-backend/internal/features/backups/backups/download"
    "databasus-backend/internal/features/backups/backups/encryption"
@@ -18,6 +17,7 @@ import (
    encryption_secrets "databasus-backend/internal/features/encryption/secrets"
    "databasus-backend/internal/features/notifiers"
    "databasus-backend/internal/features/storages"
    task_cancellation "databasus-backend/internal/features/tasks/cancellation"
    users_models "databasus-backend/internal/features/users/models"
    workspaces_services "databasus-backend/internal/features/workspaces/services"
    util_encryption "databasus-backend/internal/util/encryption"
@@ -43,7 +43,7 @@ type BackupService struct {

    workspaceService       *workspaces_services.WorkspaceService
    auditLogService        *audit_logs.AuditLogService
    backupCancelManager    *backups_cancellation.BackupCancelManager
    taskCancelManager      *task_cancellation.TaskCancelManager
    downloadTokenService   *backups_download.DownloadTokenService
    backupSchedulerService *backuping.BackupsScheduler
}
@@ -226,7 +226,7 @@ func (s *BackupService) CancelBackup(
        return errors.New("backup is not in progress")
    }

    if err := s.backupCancelManager.CancelBackup(backupID); err != nil {
    if err := s.taskCancelManager.CancelTask(backupID); err != nil {
        return err
    }

@@ -481,7 +481,7 @@ func (s *BackupService) GenerateDownloadToken(

func (s *BackupService) ValidateDownloadToken(
    token string,
) (*backups_download.DownloadToken, error) {
) (*backups_download.DownloadToken, *backups_download.RateLimiter, error) {
    return s.downloadTokenService.ValidateAndConsume(token)
}

@@ -522,6 +522,22 @@ func (s *BackupService) WriteAuditLogForDownload(
    )
}

func (s *BackupService) RefreshDownloadLock(userID uuid.UUID) {
    s.downloadTokenService.RefreshDownloadLock(userID)
}

func (s *BackupService) ReleaseDownloadLock(userID uuid.UUID) {
    s.downloadTokenService.ReleaseDownloadLock(userID)
}

func (s *BackupService) IsDownloadInProgress(userID uuid.UUID) bool {
    return s.downloadTokenService.IsDownloadInProgress(userID)
}

func (s *BackupService) UnregisterDownload(userID uuid.UUID) {
    s.downloadTokenService.UnregisterDownload(userID)
}

func (s *BackupService) generateBackupFilename(
    backup *backups_core.Backup,
    database *databases.Database,

@@ -400,6 +400,7 @@ func HasPrivilege(privileges, priv string) bool {

func (m *MysqlDatabase) buildDSN(password string, database string) string {
    tlsConfig := "false"
    allowCleartext := ""

    if m.IsHttps {
        err := mysql.RegisterTLSConfig("mysql-skip-verify", &tls.Config{
@@ -411,16 +412,18 @@ func (m *MysqlDatabase) buildDSN(password string, database string) string {
        }

        tlsConfig = "mysql-skip-verify"
        allowCleartext = "&allowCleartextPasswords=1"
    }

    return fmt.Sprintf(
        "%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4",
        "%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4%s",
        m.Username,
        password,
        m.Host,
        m.Port,
        database,
        tlsConfig,
        allowCleartext,
    )
}

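For reference, with IsHttps set the amended format string now yields a DSN of roughly this shape (host, credentials, and database are placeholders):

user:secret@tcp(db.example.com:3306)/mydb?parseTime=true&timeout=15s&tls=mysql-skip-verify&charset=utf8mb4&allowCleartextPasswords=1

For plain connections allowCleartext stays empty, so the DSN is unchanged from before this hunk.
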
@@ -85,6 +85,42 @@ func (p *PostgresqlDatabase) Validate() error {
        return errors.New("cpu count must be greater than 0")
    }

    // Prevent Databasus from backing up itself.
    // Databasus runs an internal PostgreSQL instance that should not be backed up through the UI,
    // because it would expose internal metadata to non-system administrators.
    // To properly back up Databasus, see: https://databasus.com/faq#backup-databasus
    if p.Database != nil && *p.Database != "" {
        localhostHosts := []string{
            "localhost",
            "127.0.0.1",
            "172.17.0.1",
            "host.docker.internal",
            "::1",     // IPv6 loopback (equivalent to 127.0.0.1)
            "::",      // IPv6 all interfaces (equivalent to 0.0.0.0)
            "0.0.0.0", // IPv4 all interfaces
        }

        isLocalhost := false

        for _, host := range localhostHosts {
            if strings.EqualFold(p.Host, host) {
                isLocalhost = true
                break
            }
        }

        // Also check whether the host falls anywhere in the 127.0.0.0/8 loopback range
        if strings.HasPrefix(p.Host, "127.") {
            isLocalhost = true
        }

        if isLocalhost && strings.EqualFold(*p.Database, "databasus") {
            return errors.New(
                "backing up Databasus internal database is not allowed. To backup Databasus itself, see https://databasus.com/faq#backup-databasus",
            )
        }
    }

    return nil
}

@@ -705,6 +705,269 @@ func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
    assert.Contains(t, err.Error(), "permission denied")
}

func Test_Validate_WhenLocalhostAndDatabasus_ReturnsError(t *testing.T) {
    testCases := []struct {
        name     string
        host     string
        username string
        database string
    }{
        {
            name:     "localhost with databasus db",
            host:     "localhost",
            username: "postgres",
            database: "databasus",
        },
        {
            name:     "127.0.0.1 with databasus db",
            host:     "127.0.0.1",
            username: "postgres",
            database: "databasus",
        },
        {
            name:     "172.17.0.1 with databasus db",
            host:     "172.17.0.1",
            username: "postgres",
            database: "databasus",
        },
        {
            name:     "host.docker.internal with databasus db",
            host:     "host.docker.internal",
            username: "postgres",
            database: "databasus",
        },
        {
            name:     "LOCALHOST (uppercase) with DATABASUS db",
            host:     "LOCALHOST",
            username: "POSTGRES",
            database: "DATABASUS",
        },
        {
            name:     "LocalHost (mixed case) with DataBasus db",
            host:     "LocalHost",
            username: "anyuser",
            database: "DataBasus",
        },
        {
            name:     "localhost with databasus and any username",
            host:     "localhost",
            username: "myuser",
            database: "databasus",
        },
        {
            name:     "::1 (IPv6 loopback) with databasus db",
            host:     "::1",
            username: "postgres",
            database: "databasus",
        },
        {
            name:     ":: (IPv6 all interfaces) with databasus db",
            host:     "::",
            username: "postgres",
            database: "databasus",
        },
        {
            name:     "::1 (uppercase) with DATABASUS db",
            host:     "::1",
            username: "POSTGRES",
            database: "DATABASUS",
        },
        {
            name:     "0.0.0.0 (all IPv4 interfaces) with databasus db",
            host:     "0.0.0.0",
            username: "postgres",
            database: "databasus",
        },
        {
            name:     "127.0.0.2 (loopback range) with databasus db",
            host:     "127.0.0.2",
            username: "postgres",
            database: "databasus",
        },
        {
            name:     "127.255.255.255 (end of loopback range) with databasus db",
            host:     "127.255.255.255",
            username: "postgres",
            database: "databasus",
        },
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            pgModel := &PostgresqlDatabase{
                Host:     tc.host,
                Port:     5437,
                Username: tc.username,
                Password: "somepassword",
                Database: &tc.database,
                CpuCount: 1,
            }

            err := pgModel.Validate()
            assert.Error(t, err)
            assert.Contains(t, err.Error(), "backing up Databasus internal database is not allowed")
            assert.Contains(t, err.Error(), "https://databasus.com/faq#backup-databasus")
        })
    }
}

func Test_Validate_WhenNotLocalhostOrNotDatabasus_ValidatesSuccessfully(t *testing.T) {
    testCases := []struct {
        name     string
        host     string
        username string
        database string
    }{
        {
            name:     "different host (remote server) with databasus db",
            host:     "192.168.1.100",
            username: "postgres",
            database: "databasus",
        },
        {
            name:     "different database name on localhost",
            host:     "localhost",
            username: "postgres",
            database: "myapp",
        },
        {
            name:     "all different",
            host:     "db.example.com",
            username: "appuser",
            database: "production",
        },
        {
            name:     "localhost with postgres database",
            host:     "localhost",
            username: "postgres",
            database: "postgres",
        },
        {
            name:     "remote host with databasus db name (allowed)",
            host:     "db.example.com",
            username: "postgres",
            database: "databasus",
        },
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            pgModel := &PostgresqlDatabase{
                Host:     tc.host,
                Port:     5432,
                Username: tc.username,
                Password: "somepassword",
                Database: &tc.database,
                CpuCount: 1,
            }

            err := pgModel.Validate()
            assert.NoError(t, err)
        })
    }
}

func Test_Validate_WhenDatabaseIsNil_ValidatesSuccessfully(t *testing.T) {
    pgModel := &PostgresqlDatabase{
        Host:     "localhost",
        Port:     5437,
        Username: "postgres",
        Password: "somepassword",
        Database: nil,
        CpuCount: 1,
    }

    err := pgModel.Validate()
    assert.NoError(t, err)
}

func Test_Validate_WhenDatabaseIsEmpty_ValidatesSuccessfully(t *testing.T) {
    emptyDb := ""
    pgModel := &PostgresqlDatabase{
        Host:     "localhost",
        Port:     5437,
        Username: "postgres",
        Password: "somepassword",
        Database: &emptyDb,
        CpuCount: 1,
    }

    err := pgModel.Validate()
    assert.NoError(t, err)
}

func Test_Validate_WhenRequiredFieldsMissing_ReturnsError(t *testing.T) {
    testCases := []struct {
        name          string
        model         *PostgresqlDatabase
        expectedError string
    }{
        {
            name: "missing host",
            model: &PostgresqlDatabase{
                Host:     "",
                Port:     5432,
                Username: "user",
                Password: "pass",
                CpuCount: 1,
            },
            expectedError: "host is required",
        },
        {
            name: "missing port",
            model: &PostgresqlDatabase{
                Host:     "localhost",
                Port:     0,
                Username: "user",
                Password: "pass",
                CpuCount: 1,
            },
            expectedError: "port is required",
        },
        {
            name: "missing username",
            model: &PostgresqlDatabase{
                Host:     "localhost",
                Port:     5432,
                Username: "",
                Password: "pass",
                CpuCount: 1,
            },
            expectedError: "username is required",
        },
        {
            name: "missing password",
            model: &PostgresqlDatabase{
                Host:     "localhost",
                Port:     5432,
                Username: "user",
                Password: "",
                CpuCount: 1,
            },
            expectedError: "password is required",
        },
        {
            name: "invalid cpu count",
            model: &PostgresqlDatabase{
                Host:     "localhost",
                Port:     5432,
                Username: "user",
                Password: "pass",
                CpuCount: 0,
            },
            expectedError: "cpu count must be greater than 0",
        },
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            err := tc.model.Validate()
            assert.Error(t, err)
            assert.Contains(t, err.Error(), tc.expectedError)
        })
    }
}

type PostgresContainer struct {
    Host string
    Port int

@@ -0,0 +1,75 @@
package task_cancellation

import (
    "context"
    cache_utils "databasus-backend/internal/util/cache"
    "log/slog"
    "sync"

    "github.com/google/uuid"
)

const taskCancelChannel = "task:cancel"

type TaskCancelManager struct {
    mu          sync.RWMutex
    cancelFuncs map[uuid.UUID]context.CancelFunc
    pubsub      *cache_utils.PubSubManager
    logger      *slog.Logger
}

func (m *TaskCancelManager) StartSubscription() {
    ctx := context.Background()

    handler := func(message string) {
        taskID, err := uuid.Parse(message)
        if err != nil {
            m.logger.Error("Invalid task ID in cancel message", "message", message, "error", err)
            return
        }

        m.mu.Lock()
        defer m.mu.Unlock()

        cancelFunc, exists := m.cancelFuncs[taskID]
        if exists {
            cancelFunc()
            delete(m.cancelFuncs, taskID)
            m.logger.Info("Cancelled task via Pub/Sub", "taskID", taskID)
        }
    }

    err := m.pubsub.Subscribe(ctx, taskCancelChannel, handler)
    if err != nil {
        m.logger.Error("Failed to subscribe to task cancel channel", "error", err)
    } else {
        m.logger.Info("Successfully subscribed to task cancel channel")
    }
}

func (m *TaskCancelManager) RegisterTask(task uuid.UUID, cancelFunc context.CancelFunc) {
    m.mu.Lock()
    defer m.mu.Unlock()
    m.cancelFuncs[task] = cancelFunc
    m.logger.Debug("Registered task", "taskID", task)
}

func (m *TaskCancelManager) CancelTask(taskID uuid.UUID) error {
    ctx := context.Background()

    err := m.pubsub.Publish(ctx, taskCancelChannel, taskID.String())
    if err != nil {
        m.logger.Error("Failed to publish cancel message", "taskID", taskID, "error", err)
        return err
    }

    m.logger.Info("Published task cancel message", "taskID", taskID)
    return nil
}

func (m *TaskCancelManager) UnregisterTask(taskID uuid.UUID) {
    m.mu.Lock()
    defer m.mu.Unlock()
    delete(m.cancelFuncs, taskID)
    m.logger.Debug("Unregistered task", "taskID", taskID)
}

@@ -1,4 +1,4 @@
package backups_cancellation
package task_cancellation

import (
    "context"
@@ -10,41 +10,41 @@ import (
    "github.com/stretchr/testify/assert"
)

func Test_RegisterBackup_BackupRegisteredSuccessfully(t *testing.T) {
    manager := backupCancelManager
func Test_RegisterTask_TaskRegisteredSuccessfully(t *testing.T) {
    manager := taskCancelManager

    backupID := uuid.New()
    taskID := uuid.New()
    _, cancel := context.WithCancel(context.Background())
    defer cancel()

    manager.RegisterBackup(backupID, cancel)
    manager.RegisterTask(taskID, cancel)

    manager.mu.RLock()
    _, exists := manager.cancelFuncs[backupID]
    _, exists := manager.cancelFuncs[taskID]
    manager.mu.RUnlock()
    assert.True(t, exists, "Backup should be registered")
    assert.True(t, exists, "Task should be registered")
}

func Test_UnregisterBackup_BackupUnregisteredSuccessfully(t *testing.T) {
    manager := backupCancelManager
func Test_UnregisterTask_TaskUnregisteredSuccessfully(t *testing.T) {
    manager := taskCancelManager

    backupID := uuid.New()
    taskID := uuid.New()
    _, cancel := context.WithCancel(context.Background())
    defer cancel()

    manager.RegisterBackup(backupID, cancel)
    manager.UnregisterBackup(backupID)
    manager.RegisterTask(taskID, cancel)
    manager.UnregisterTask(taskID)

    manager.mu.RLock()
    _, exists := manager.cancelFuncs[backupID]
    _, exists := manager.cancelFuncs[taskID]
    manager.mu.RUnlock()
    assert.False(t, exists, "Backup should be unregistered")
    assert.False(t, exists, "Task should be unregistered")
}

func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
    manager := backupCancelManager
func Test_CancelTask_OnSameInstance_TaskCancelledViaPubSub(t *testing.T) {
    manager := taskCancelManager

    backupID := uuid.New()
    taskID := uuid.New()
    ctx, cancel := context.WithCancel(context.Background())

    cancelled := false
@@ -57,11 +57,11 @@ func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
        cancel()
    }

    manager.RegisterBackup(backupID, wrappedCancel)
    manager.RegisterTask(taskID, wrappedCancel)
    manager.StartSubscription()
    time.Sleep(100 * time.Millisecond)

    err := manager.CancelBackup(backupID)
    err := manager.CancelTask(taskID)
    assert.NoError(t, err, "Cancel should not return error")

    time.Sleep(500 * time.Millisecond)
@@ -74,11 +74,11 @@ func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
    assert.Error(t, ctx.Err(), "Context should be cancelled")
}

func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t *testing.T) {
    manager1 := backupCancelManager
    manager2 := backupCancelManager
func Test_CancelTask_FromDifferentInstance_TaskCancelledOnRunningInstance(t *testing.T) {
    manager1 := taskCancelManager
    manager2 := taskCancelManager

    backupID := uuid.New()
    taskID := uuid.New()
    ctx, cancel := context.WithCancel(context.Background())

    cancelled := false
@@ -91,13 +91,13 @@ func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t
        cancel()
    }

    manager1.RegisterBackup(backupID, wrappedCancel)
    manager1.RegisterTask(taskID, wrappedCancel)

    manager1.StartSubscription()
    manager2.StartSubscription()
    time.Sleep(100 * time.Millisecond)

    err := manager2.CancelBackup(backupID)
    err := manager2.CancelTask(taskID)
    assert.NoError(t, err, "Cancel should not return error")

    time.Sleep(500 * time.Millisecond)
@@ -110,29 +110,29 @@ func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t
    assert.Error(t, ctx.Err(), "Context should be cancelled")
}

func Test_CancelBackup_WhenBackupDoesNotExist_NoErrorReturned(t *testing.T) {
    manager := backupCancelManager
func Test_CancelTask_WhenTaskDoesNotExist_NoErrorReturned(t *testing.T) {
    manager := taskCancelManager

    manager.StartSubscription()
    time.Sleep(100 * time.Millisecond)

    nonExistentID := uuid.New()
    err := manager.CancelBackup(nonExistentID)
    assert.NoError(t, err, "Cancelling non-existent backup should not error")
    err := manager.CancelTask(nonExistentID)
    assert.NoError(t, err, "Cancelling non-existent task should not error")
}

func Test_CancelBackup_WithMultipleBackups_AllBackupsCancelled(t *testing.T) {
    manager := backupCancelManager
func Test_CancelTask_WithMultipleTasks_AllTasksCancelled(t *testing.T) {
    manager := taskCancelManager

    numBackups := 5
    backupIDs := make([]uuid.UUID, numBackups)
    contexts := make([]context.Context, numBackups)
    cancels := make([]context.CancelFunc, numBackups)
    cancelledFlags := make([]bool, numBackups)
    numTasks := 5
    taskIDs := make([]uuid.UUID, numTasks)
    contexts := make([]context.Context, numTasks)
    cancels := make([]context.CancelFunc, numTasks)
    cancelledFlags := make([]bool, numTasks)
    var mu sync.Mutex

    for i := 0; i < numBackups; i++ {
        backupIDs[i] = uuid.New()
    for i := 0; i < numTasks; i++ {
        taskIDs[i] = uuid.New()
        contexts[i], cancels[i] = context.WithCancel(context.Background())

        idx := i
@@ -143,31 +143,31 @@ func Test_CancelBackup_WithMultipleBackups_AllBackupsCancelled(t *testing.T) {
            cancels[idx]()
        }

        manager.RegisterBackup(backupIDs[i], wrappedCancel)
        manager.RegisterTask(taskIDs[i], wrappedCancel)
    }

    manager.StartSubscription()
    time.Sleep(100 * time.Millisecond)

    for i := 0; i < numBackups; i++ {
        err := manager.CancelBackup(backupIDs[i])
    for i := 0; i < numTasks; i++ {
        err := manager.CancelTask(taskIDs[i])
        assert.NoError(t, err, "Cancel should not return error")
    }

    time.Sleep(1 * time.Second)

    mu.Lock()
    for i := 0; i < numBackups; i++ {
        assert.True(t, cancelledFlags[i], "Backup %d should be cancelled", i)
    for i := 0; i < numTasks; i++ {
        assert.True(t, cancelledFlags[i], "Task %d should be cancelled", i)
        assert.Error(t, contexts[i].Err(), "Context %d should be cancelled", i)
    }
    mu.Unlock()
}

func Test_CancelBackup_AfterUnregister_BackupNotCancelled(t *testing.T) {
    manager := backupCancelManager
func Test_CancelTask_AfterUnregister_TaskNotCancelled(t *testing.T) {
    manager := taskCancelManager

    backupID := uuid.New()
    taskID := uuid.New()
    _, cancel := context.WithCancel(context.Background())
    defer cancel()

@@ -181,13 +181,13 @@ func Test_CancelBackup_AfterUnregister_BackupNotCancelled(t *testing.T) {
        cancel()
    }

    manager.RegisterBackup(backupID, wrappedCancel)
    manager.RegisterTask(taskID, wrappedCancel)
    manager.StartSubscription()
    time.Sleep(100 * time.Millisecond)

    manager.UnregisterBackup(backupID)
    manager.UnregisterTask(taskID)

    err := manager.CancelBackup(backupID)
    err := manager.CancelTask(taskID)
    assert.NoError(t, err, "Cancel should not return error")

    time.Sleep(500 * time.Millisecond)
@@ -1,4 +1,4 @@
package backups_cancellation
package task_cancellation

import (
    "context"
@@ -9,17 +9,17 @@ import (
    "github.com/google/uuid"
)

var backupCancelManager = &BackupCancelManager{
var taskCancelManager = &TaskCancelManager{
    sync.RWMutex{},
    make(map[uuid.UUID]context.CancelFunc),
    cache_utils.NewPubSubManager(),
    logger.GetLogger(),
}

func GetBackupCancelManager() *BackupCancelManager {
    return backupCancelManager
func GetTaskCancelManager() *TaskCancelManager {
    return taskCancelManager
}

func SetupDependencies() {
    backupCancelManager.StartSubscription()
    taskCancelManager.StartSubscription()
}
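
A minimal in-package sketch of the intended task lifecycle (hypothetical function name; cancellation works across instances because CancelTask publishes to the task:cancel Pub/Sub channel rather than touching the local map directly):

// exampleCancellableTask is a hypothetical illustration, not committed code.
func exampleCancellableTask(taskID uuid.UUID) error {
    manager := GetTaskCancelManager()

    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    manager.RegisterTask(taskID, cancel)
    defer manager.UnregisterTask(taskID)

    for step := 0; step < 100; step++ {
        if ctx.Err() != nil {
            // Cancelled locally or from another instance via Pub/Sub.
            return ctx.Err()
        }
        // ... perform one unit of work ...
    }
    return nil
}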

18
backend/internal/features/tasks/registry/di.go
Normal file
@@ -0,0 +1,18 @@
package task_registry

import (
    cache_utils "databasus-backend/internal/util/cache"
    "databasus-backend/internal/util/logger"
)

var taskNodesRegistry = &TaskNodesRegistry{
    cache_utils.GetValkeyClient(),
    logger.GetLogger(),
    cache_utils.DefaultCacheTimeout,
    cache_utils.NewPubSubManager(),
    cache_utils.NewPubSubManager(),
}

func GetTaskNodesRegistry() *TaskNodesRegistry {
    return taskNodesRegistry
}

29
backend/internal/features/tasks/registry/dto.go
Normal file
@@ -0,0 +1,29 @@
package task_registry

import (
    "time"

    "github.com/google/uuid"
)

type TaskNode struct {
    ID            uuid.UUID `json:"id"`
    ThroughputMBs int       `json:"throughputMBs"`
    LastHeartbeat time.Time `json:"lastHeartbeat"`
}

type TaskNodeStats struct {
    ID          uuid.UUID `json:"id"`
    ActiveTasks int       `json:"activeTasks"`
}

type TaskSubmitMessage struct {
    NodeID         string `json:"nodeId"`
    TaskID         string `json:"taskId"`
    IsCallNotifier bool   `json:"isCallNotifier"`
}

type TaskCompletionMessage struct {
    NodeID string `json:"nodeId"`
    TaskID string `json:"taskId"`
}

641
backend/internal/features/tasks/registry/registry.go
Normal file
@@ -0,0 +1,641 @@
|
||||
package task_registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/valkey-io/valkey-go"
|
||||
)
|
||||
|
||||
const (
|
||||
nodeInfoKeyPrefix = "node:"
|
||||
nodeInfoKeySuffix = ":info"
|
||||
nodeActiveTasksPrefix = "node:"
|
||||
nodeActiveTasksSuffix = ":active_tasks"
|
||||
taskSubmitChannel = "task:submit"
|
||||
taskCompletionChannel = "task:completion"
|
||||
|
||||
deadNodeThreshold = 2 * time.Minute
|
||||
cleanupTickerInterval = 1 * time.Second
|
||||
)
|
||||
|
||||
// TaskNodesRegistry helps to sync tasks scheduler (backuping or restoring)
|
||||
// and task nodes which are used for network-intensive tasks processing
|
||||
//
|
||||
// Features:
|
||||
// - Track node availability and load level
|
||||
// - Assign from scheduler to node tasks needed to be processed
|
||||
// - Notify scheduler from node about task completion
|
||||
//
|
||||
// Important things to remember:
|
||||
// - Node can contain different tasks types so when task is assigned
|
||||
// or node's tasks cleaned - should be performed DB check in DB
|
||||
// that task with this ID exists for this task type at all
|
||||
// - Nodes without heathbeat for more than 2 minutes are not included
|
||||
// in available nodes list and stats
|
||||
//
|
||||
// Cleanup dead nodes performed on 2 levels:
|
||||
// - List and stats functions do not return dead nodes
|
||||
// - Periodically dead nodes are cleaned up in cache (to not
|
||||
// accumulate too many dead nodes in cache)
|
||||
type TaskNodesRegistry struct {
|
||||
client valkey.Client
|
||||
logger *slog.Logger
|
||||
timeout time.Duration
|
||||
pubsubTasks *cache_utils.PubSubManager
|
||||
pubsubCompletions *cache_utils.PubSubManager
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) Run(ctx context.Context) {
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(cleanupTickerInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) GetAvailableNodes() ([]TaskNode, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return nil, fmt.Errorf("failed to scan node keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return []TaskNode{}, nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var nodes []TaskNode
|
||||
for key, data := range keyDataMap {
|
||||
// Skip if the key doesn't exist (data is empty)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node TaskNode
|
||||
if err := json.Unmarshal(data, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip nodes with zero/uninitialized heartbeat
|
||||
if node.LastHeartbeat.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.LastHeartbeat.Before(threshold) {
|
||||
continue
|
||||
}
|
||||
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeActiveTasksPrefix + "*" + nodeActiveTasksSuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return nil, fmt.Errorf("failed to scan active tasks keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return []TaskNodeStats{}, nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get active tasks keys: %w", err)
|
||||
}
|
||||
|
||||
var nodeInfoKeys []string
|
||||
nodeIDToStatsKey := make(map[string]string)
|
||||
for key := range keyDataMap {
|
||||
nodeID := r.extractNodeIDFromKey(key, nodeActiveTasksPrefix, nodeActiveTasksSuffix)
|
||||
nodeIDStr := nodeID.String()
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeIDStr, nodeInfoKeySuffix)
|
||||
nodeInfoKeys = append(nodeInfoKeys, infoKey)
|
||||
nodeIDToStatsKey[infoKey] = key
|
||||
}
|
||||
|
||||
nodeInfoMap, err := r.pipelineGetKeys(nodeInfoKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get node info keys: %w", err)
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var stats []TaskNodeStats
|
||||
for infoKey, nodeData := range nodeInfoMap {
|
||||
// Skip if the info key doesn't exist (nodeData is empty)
|
||||
if len(nodeData) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node TaskNode
|
||||
if err := json.Unmarshal(nodeData, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data", "key", infoKey, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip nodes with zero/uninitialized heartbeat
|
||||
if node.LastHeartbeat.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.LastHeartbeat.Before(threshold) {
|
||||
continue
|
||||
}
|
||||
|
||||
statsKey := nodeIDToStatsKey[infoKey]
|
||||
tasksData := keyDataMap[statsKey]
|
||||
count, err := r.parseIntFromBytes(tasksData)
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse active tasks count", "key", statsKey, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
stat := TaskNodeStats{
|
||||
ID: node.ID,
|
||||
ActiveTasks: int(count),
|
||||
}
|
||||
stats = append(stats, stat)
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) IncrementTasksInProgress(nodeID string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix)
|
||||
result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to increment tasks in progress for node %s: %w",
|
||||
nodeID,
|
||||
result.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) DecrementTasksInProgress(nodeID string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix)
|
||||
result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to decrement tasks in progress for node %s: %w",
|
||||
nodeID,
|
||||
result.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
newValue, err := result.AsInt64()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse decremented value for node %s: %w", nodeID, err)
|
||||
}
|
||||
|
||||
if newValue < 0 {
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
|
||||
setCancel()
|
||||
r.logger.Warn("Active tasks counter went below 0, reset to 0", "nodeID", nodeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) HearthbeatNodeInRegistry(now time.Time, node TaskNode) error {
|
||||
if now.IsZero() {
|
||||
return fmt.Errorf("cannot register node with zero heartbeat timestamp")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
node.LastHeartbeat = now
|
||||
|
||||
data, err := json.Marshal(node)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal node: %w", err)
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node.ID.String(), nodeInfoKeySuffix)
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Set().Key(key).Value(string(data)).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to register node %s: %w", node.ID, result.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) UnregisterNodeFromRegistry(node TaskNode) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node.ID.String(), nodeInfoKeySuffix)
|
||||
counterKey := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveTasksPrefix,
|
||||
node.ID.String(),
|
||||
nodeActiveTasksSuffix,
|
||||
)
|
||||
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Del().Key(infoKey, counterKey).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to unregister node %s: %w", node.ID, result.Error())
|
||||
}
|
||||
|
||||
r.logger.Info("Unregistered node from registry", "nodeID", node.ID)
|
||||
return nil
|
||||
}

func (r *TaskNodesRegistry) AssignTaskToNode(
    targetNodeID string,
    taskID uuid.UUID,
    isCallNotifier bool,
) error {
    ctx := context.Background()

    message := TaskSubmitMessage{
        NodeID:         targetNodeID,
        TaskID:         taskID.String(),
        IsCallNotifier: isCallNotifier,
    }

    messageJSON, err := json.Marshal(message)
    if err != nil {
        return fmt.Errorf("failed to marshal task submit message: %w", err)
    }

    err = r.pubsubTasks.Publish(ctx, taskSubmitChannel, string(messageJSON))
    if err != nil {
        return fmt.Errorf("failed to publish task submit message: %w", err)
    }

    return nil
}

func (r *TaskNodesRegistry) SubscribeNodeForTasksAssignment(
    nodeID string,
    handler func(taskID uuid.UUID, isCallNotifier bool),
) error {
    ctx := context.Background()

    wrappedHandler := func(message string) {
        var msg TaskSubmitMessage
        if err := json.Unmarshal([]byte(message), &msg); err != nil {
            r.logger.Warn("Failed to unmarshal task submit message", "error", err)
            return
        }

        if msg.NodeID != nodeID {
            return
        }

        taskID, err := uuid.Parse(msg.TaskID)
        if err != nil {
            r.logger.Warn(
                "Failed to parse task ID from message",
                "taskId",
                msg.TaskID,
                "error",
                err,
            )
            return
        }

        handler(taskID, msg.IsCallNotifier)
    }

    err := r.pubsubTasks.Subscribe(ctx, taskSubmitChannel, wrappedHandler)
    if err != nil {
        return fmt.Errorf("failed to subscribe to task submit channel: %w", err)
    }

    r.logger.Info("Subscribed to task submit channel", "nodeID", nodeID)
    return nil
}
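
// Sketch only (hypothetical names): assignment is broadcast-then-filter. Every
// subscriber receives every TaskSubmitMessage; the wrapped handler above drops
// messages addressed to other node IDs, so the callback only fires for tasks
// routed to this node.
func exampleAssignmentWiring(registry *TaskNodesRegistry, myNodeID string) error {
    // Worker side: subscribe once, handle only this node's tasks.
    err := registry.SubscribeNodeForTasksAssignment(myNodeID, func(taskID uuid.UUID, isCallNotifier bool) {
        go processTask(taskID, isCallNotifier) // hypothetical executor
    })
    if err != nil {
        return err
    }

    // Scheduler side: route one task to the chosen node.
    return registry.AssignTaskToNode(myNodeID, uuid.New(), false)
}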

func (r *TaskNodesRegistry) UnsubscribeNodeForTasksAssignments() error {
    err := r.pubsubTasks.Close()
    if err != nil {
        return fmt.Errorf("failed to unsubscribe from task submit channel: %w", err)
    }

    r.logger.Info("Unsubscribed from task submit channel")
    return nil
}

func (r *TaskNodesRegistry) PublishTaskCompletion(nodeID string, taskID uuid.UUID) error {
    ctx := context.Background()

    message := TaskCompletionMessage{
        NodeID: nodeID,
        TaskID: taskID.String(),
    }

    messageJSON, err := json.Marshal(message)
    if err != nil {
        return fmt.Errorf("failed to marshal task completion message: %w", err)
    }

    err = r.pubsubCompletions.Publish(ctx, taskCompletionChannel, string(messageJSON))
    if err != nil {
        return fmt.Errorf("failed to publish task completion message: %w", err)
    }

    return nil
}

func (r *TaskNodesRegistry) SubscribeForTasksCompletions(
    handler func(nodeID string, taskID uuid.UUID),
) error {
    ctx := context.Background()

    wrappedHandler := func(message string) {
        var msg TaskCompletionMessage
        if err := json.Unmarshal([]byte(message), &msg); err != nil {
            r.logger.Warn("Failed to unmarshal task completion message", "error", err)
            return
        }

        taskID, err := uuid.Parse(msg.TaskID)
        if err != nil {
            r.logger.Warn(
                "Failed to parse task ID from completion message",
                "taskId",
                msg.TaskID,
                "error",
                err,
            )
            return
        }

        handler(msg.NodeID, taskID)
    }

    err := r.pubsubCompletions.Subscribe(ctx, taskCompletionChannel, wrappedHandler)
    if err != nil {
        return fmt.Errorf("failed to subscribe to task completion channel: %w", err)
    }

    r.logger.Info("Subscribed to task completion channel")
    return nil
}
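
// Sketch only (hypothetical names): the completion channel mirrors the
// assignment channel in the opposite direction; workers publish, the
// scheduler subscribes once and does its bookkeeping in the handler.
func exampleCompletionWiring(registry *TaskNodesRegistry, myNodeID string, taskID uuid.UUID) error {
    // Scheduler side: observe completions from every node.
    if err := registry.SubscribeForTasksCompletions(func(nodeID string, completedTaskID uuid.UUID) {
        markTaskDone(nodeID, completedTaskID) // hypothetical bookkeeping
    }); err != nil {
        return err
    }

    // Worker side, after a task finishes.
    return registry.PublishTaskCompletion(myNodeID, taskID)
}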

func (r *TaskNodesRegistry) UnsubscribeForTasksCompletions() error {
    err := r.pubsubCompletions.Close()
    if err != nil {
        return fmt.Errorf("failed to unsubscribe from task completion channel: %w", err)
    }

    r.logger.Info("Unsubscribed from task completion channel")
    return nil
}

func (r *TaskNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
    nodeIDStr := strings.TrimPrefix(key, prefix)
    nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix)

    nodeID, err := uuid.Parse(nodeIDStr)
    if err != nil {
        r.logger.Warn("Failed to parse node ID from key", "key", key, "error", err)
        return uuid.Nil
    }

    return nodeID
}

func (r *TaskNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
    if len(keys) == 0 {
        return make(map[string][]byte), nil
    }

    ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
    defer cancel()

    commands := make([]valkey.Completed, 0, len(keys))
    for _, key := range keys {
        commands = append(commands, r.client.B().Get().Key(key).Build())
    }

    results := r.client.DoMulti(ctx, commands...)

    keyDataMap := make(map[string][]byte, len(keys))
    for i, result := range results {
        if result.Error() != nil {
            r.logger.Warn("Failed to get key in pipeline", "key", keys[i], "error", result.Error())
            continue
        }

        data, err := result.AsBytes()
        if err != nil {
            r.logger.Warn("Failed to parse key data in pipeline", "key", keys[i], "error", err)
            continue
        }

        keyDataMap[keys[i]] = data
    }

    return keyDataMap, nil
}
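
// Sketch only (hypothetical helper): pipelineGetKeys batches N GETs into a
// single DoMulti round trip and is deliberately lossy, so a caller decodes
// whatever survived and skips the rest, as cleanupDeadNodes does below.
func (r *TaskNodesRegistry) decodeNodeBatch(keys []string) ([]TaskNode, error) {
    keyDataMap, err := r.pipelineGetKeys(keys)
    if err != nil {
        return nil, err
    }
    nodes := make([]TaskNode, 0, len(keyDataMap))
    for key, data := range keyDataMap {
        var node TaskNode
        if err := json.Unmarshal(data, &node); err != nil {
            r.logger.Warn("skipping undecodable node", "key", key, "error", err)
            continue
        }
        nodes = append(nodes, node)
    }
    return nodes, nil
}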

func (r *TaskNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
    str := string(data)
    var count int64
    _, err := fmt.Sscanf(str, "%d", &count)
    if err != nil {
        return 0, fmt.Errorf("failed to parse integer from bytes: %w", err)
    }
    return count, nil
}

func (r *TaskNodesRegistry) cleanupDeadNodes() error {
    ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
    defer cancel()

    var allKeys []string
    cursor := uint64(0)
    pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix

    for {
        result := r.client.Do(
            ctx,
            r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
        )

        if result.Error() != nil {
            return fmt.Errorf("failed to scan node keys: %w", result.Error())
        }

        scanResult, err := result.AsScanEntry()
        if err != nil {
            return fmt.Errorf("failed to parse scan result: %w", err)
        }

        allKeys = append(allKeys, scanResult.Elements...)

        cursor = scanResult.Cursor
        if cursor == 0 {
            break
        }
    }

    if len(allKeys) == 0 {
        return nil
    }

    keyDataMap, err := r.pipelineGetKeys(allKeys)
    if err != nil {
        return fmt.Errorf("failed to pipeline get node keys: %w", err)
    }

    threshold := time.Now().UTC().Add(-deadNodeThreshold)
    var deadNodeKeys []string

    for key, data := range keyDataMap {
        // Skip if the key doesn't exist (data is empty)
        if len(data) == 0 {
            continue
        }

        var node TaskNode
        if err := json.Unmarshal(data, &node); err != nil {
            r.logger.Warn("Failed to unmarshal node data during cleanup", "key", key, "error", err)
            continue
        }

        // Skip nodes with zero/uninitialized heartbeat
        if node.LastHeartbeat.IsZero() {
            continue
        }

        if node.LastHeartbeat.Before(threshold) {
            nodeID := node.ID.String()
            infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeID, nodeInfoKeySuffix)
            statsKey := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix)

            deadNodeKeys = append(deadNodeKeys, infoKey, statsKey)
            r.logger.Info(
                "Marking node for cleanup",
                "nodeID", nodeID,
                "lastHeartbeat", node.LastHeartbeat,
                "threshold", threshold,
            )
        }
    }

    if len(deadNodeKeys) == 0 {
        return nil
    }

    delCtx, delCancel := context.WithTimeout(context.Background(), r.timeout)
    defer delCancel()

    result := r.client.Do(
        delCtx,
        r.client.B().Del().Key(deadNodeKeys...).Build(),
    )

    if result.Error() != nil {
        return fmt.Errorf("failed to delete dead node keys: %w", result.Error())
    }

    deletedCount, err := result.AsInt64()
    if err != nil {
        return fmt.Errorf("failed to parse deleted count: %w", err)
    }

    r.logger.Info("Cleaned up dead nodes", "deletedKeysCount", deletedCount)
    return nil
}
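
// Sketch only (hypothetical helper): the reaping rule above reduces to plain
// timestamp arithmetic. A node is dead when its last heartbeat predates
// now minus deadNodeThreshold; zero heartbeats are skipped rather than reaped.
func isDeadNode(node TaskNode, now time.Time, deadAfter time.Duration) bool {
    if node.LastHeartbeat.IsZero() {
        return false // uninitialized heartbeat: never reap
    }
    return node.LastHeartbeat.Before(now.Add(-deadAfter))
}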

backend/internal/features/tasks/registry/registry_test.go — new file, 1170 lines (diff suppressed because it is too large)

@@ -2,13 +2,12 @@ package users_controllers

import (
    users_services "databasus-backend/internal/features/users/services"

    "golang.org/x/time/rate"
    cache_utils "databasus-backend/internal/util/cache"
)

var userController = &UserController{
    users_services.GetUserService(),
    rate.NewLimiter(rate.Limit(3), 3), // 3 rps with 3 burst
    cache_utils.NewRateLimiter(cache_utils.GetValkeyClient()),
}

var settingsController = &SettingsController{

@@ -14,7 +14,6 @@ import (
    "github.com/gin-gonic/gin"
    "github.com/google/uuid"
    "github.com/stretchr/testify/assert"
    "golang.org/x/time/rate"
)

func Test_AdminLifecycleE2E_CompletesSuccessfully(t *testing.T) {

@@ -185,7 +184,6 @@ func createUserTestRouter() *gin.Engine {
    // Register protected routes with auth middleware
    protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
    GetUserController().RegisterProtectedRoutes(protected.(*gin.RouterGroup))
    GetUserController().SetSignInLimiter(rate.NewLimiter(rate.Limit(100), 100))

    // Setup audit log service
    users_services.GetUserService().SetAuditLogWriter(&AuditLogWriterStub{})

@@ -3,20 +3,21 @@ package users_controllers
import (
    "errors"
    "net/http"
    "time"

    "databasus-backend/internal/config"
    user_dto "databasus-backend/internal/features/users/dto"
    users_errors "databasus-backend/internal/features/users/errors"
    user_middleware "databasus-backend/internal/features/users/middleware"
    users_services "databasus-backend/internal/features/users/services"
    cache_utils "databasus-backend/internal/util/cache"

    "github.com/gin-gonic/gin"
    "golang.org/x/time/rate"
)

type UserController struct {
    userService   *users_services.UserService
    signinLimiter *rate.Limiter
    userService *users_services.UserService
    rateLimiter *cache_utils.RateLimiter
}

func (c *UserController) RegisterRoutes(router *gin.RouterGroup) {

@@ -39,10 +40,6 @@ func (c *UserController) RegisterProtectedRoutes(router *gin.RouterGroup) {
    router.POST("/users/invite", c.InviteUser)
}

func (c *UserController) SetSignInLimiter(limiter *rate.Limiter) {
    c.signinLimiter = limiter
}

// SignUp
// @Summary Register a new user
// @Description Register a new user with email and password

@@ -81,8 +78,14 @@ func (c *UserController) SignUp(ctx *gin.Context) {
// @Failure 429 {object} map[string]string "Rate limit exceeded"
// @Router /users/signin [post]
func (c *UserController) SignIn(ctx *gin.Context) {
    // We use rate limiter to prevent brute force attacks
    if !c.signinLimiter.Allow() {
    var request user_dto.SignInRequestDTO
    if err := ctx.ShouldBindJSON(&request); err != nil {
        ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
        return
    }

    allowed, _ := c.rateLimiter.CheckLimit(request.Email, "signin", 10, 1*time.Minute)
    if !allowed {
        ctx.JSON(
            http.StatusTooManyRequests,
            gin.H{"error": "Rate limit exceeded. Please try again later."},

@@ -90,12 +93,6 @@ func (c *UserController) SignIn(ctx *gin.Context) {
        return
    }

    var request user_dto.SignInRequestDTO
    if err := ctx.ShouldBindJSON(&request); err != nil {
        ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
        return
    }

    response, err := c.userService.SignIn(&request)
    if err != nil {
        ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})

@@ -1146,3 +1146,48 @@ func Test_GoogleOAuth_WithInvitedUser_ActivatesUser(t *testing.T) {
    assert.Equal(t, email, response.Email)
    assert.False(t, response.IsNewUser)
}

func Test_SignIn_WithExcessiveAttempts_RateLimitEnforced(t *testing.T) {
    router := createUserTestRouter()
    email := "ratelimit" + uuid.New().String() + "@example.com"
    password := "testpassword123"

    // Create a user first
    signupRequest := users_dto.SignUpRequestDTO{
        Email:    email,
        Password: password,
        Name:     "Rate Limit Test User",
    }
    test_utils.MakePostRequest(t, router, "/api/v1/users/signup", "", signupRequest, http.StatusOK)

    // Make 10 sign-in attempts (should succeed)
    for range 10 {
        signinRequest := users_dto.SignInRequestDTO{
            Email:    email,
            Password: password,
        }
        test_utils.MakePostRequest(
            t,
            router,
            "/api/v1/users/signin",
            "",
            signinRequest,
            http.StatusOK,
        )
    }

    // 11th attempt should be rate limited
    signinRequest := users_dto.SignInRequestDTO{
        Email:    email,
        Password: password,
    }
    resp := test_utils.MakePostRequest(
        t,
        router,
        "/api/v1/users/signin",
        "",
        signinRequest,
        http.StatusTooManyRequests,
    )
    assert.Contains(t, string(resp.Body), "Rate limit exceeded")
}

backend/internal/util/cache/cache_test.go — new file (vendored), 51 lines
@@ -0,0 +1,51 @@
package cache_utils

import (
    "testing"

    "github.com/stretchr/testify/assert"
)

func Test_ClearAllCache_AfterClear_CacheIsEmpty(t *testing.T) {
    client := getCache()

    // Arrange: Set up multiple cache entries with different prefixes
    testKeys := []struct {
        prefix string
        key    string
        value  string
    }{
        {"test:user:", "user1", "John Doe"},
        {"test:user:", "user2", "Jane Smith"},
        {"test:session:", "session1", "abc123"},
        {"test:session:", "session2", "def456"},
        {"test:data:", "item1", "value1"},
    }

    // Set all test keys
    for _, tk := range testKeys {
        cacheUtil := NewCacheUtil[string](client, tk.prefix)
        cacheUtil.Set(tk.key, &tk.value)
    }

    // Verify keys were set correctly before clearing
    for _, tk := range testKeys {
        cacheUtil := NewCacheUtil[string](client, tk.prefix)
        retrieved := cacheUtil.Get(tk.key)
        assert.NotNil(t, retrieved, "Key %s should exist before clearing", tk.prefix+tk.key)
        assert.Equal(t, tk.value, *retrieved, "Retrieved value should match set value")
    }

    // Act: Clear all cache
    err := ClearAllCache()

    // Assert: No error returned
    assert.NoError(t, err, "ClearAllCache should not return an error")

    // Assert: All keys should be deleted
    for _, tk := range testKeys {
        cacheUtil := NewCacheUtil[string](client, tk.prefix)
        retrieved := cacheUtil.Get(tk.key)
        assert.Nil(t, retrieved, "Key %s should be deleted after clearing", tk.prefix+tk.key)
    }
}

backend/internal/util/cache/rate_limiter.go — new file (vendored), 85 lines
@@ -0,0 +1,85 @@
package cache_utils

import (
    "context"
    "fmt"
    "time"

    "github.com/google/uuid"
    "github.com/valkey-io/valkey-go"
)

type RateLimiter struct {
    client valkey.Client
}

func NewRateLimiter(client valkey.Client) *RateLimiter {
    return &RateLimiter{
        client: client,
    }
}

func (r *RateLimiter) CheckLimit(
    identifier string,
    endpoint string,
    maxRequests int,
    windowDuration time.Duration,
) (bool, error) {
    requestID := uuid.New().String()
    keyPrefix := fmt.Sprintf("ratelimit:%s:%s", endpoint, identifier)
    fullKey := fmt.Sprintf("%s:%s", keyPrefix, requestID)

    ctx, cancel := context.WithTimeout(context.Background(), DefaultCacheTimeout)
    defer cancel()

    // Set the key with TTL
    setCmd := r.client.B().
        Set().
        Key(fullKey).
        Value("1").
        ExSeconds(int64(windowDuration.Seconds())).
        Build()
    if err := r.client.Do(ctx, setCmd).Error(); err != nil {
        return true, fmt.Errorf("failed to set rate limit key: %w", err)
    }

    // Count keys matching the pattern
    count, err := r.countKeys(keyPrefix)
    if err != nil {
        return true, fmt.Errorf("failed to count rate limit keys: %w", err)
    }

    return count <= maxRequests, nil
}

func (r *RateLimiter) countKeys(keyPrefix string) (int, error) {
    pattern := keyPrefix + ":*"
    cursor := uint64(0)
    totalCount := 0

    for {
        ctx, cancel := context.WithTimeout(context.Background(), DefaultCacheTimeout)

        scanCmd := r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build()
        result := r.client.Do(ctx, scanCmd)
        cancel()

        if result.Error() != nil {
            return 0, result.Error()
        }

        scanResult, err := result.AsScanEntry()
        if err != nil {
            return 0, err
        }

        totalCount += len(scanResult.Elements)
        cursor = scanResult.Cursor

        if cursor == 0 {
            break
        }
    }

    return totalCount, nil
}
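
// Sketch only (hypothetical call site, mirroring the SignIn change above).
// Design note: this is a sliding window built from key expiry. Every request
// writes one uniquely named key with a TTL equal to the window, so the
// allowance check is simply "how many of my keys are still alive"; countKeys
// pays an O(live keys) SCAN per check, and CheckLimit fails open (allowed is
// true alongside a non-nil error) so an unhealthy Valkey never locks users out.
func exampleSigninLimit(limiter *RateLimiter, email string) (bool, error) {
    // Allow at most 10 sign-in attempts per e-mail per minute.
    return limiter.CheckLimit(email, "signin", 10, time.Minute)
}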

@@ -87,17 +87,20 @@ func VerifyMariadbInstallation(
    logger *slog.Logger,
    envMode env_utils.EnvMode,
    mariadbInstallDir string,
    isShowLogs bool,
) {
    clientVersions := []MariadbClientVersion{MariadbClientLegacy, MariadbClientModern}

    for _, clientVersion := range clientVersions {
        binDir := getMariadbBasePath(clientVersion, envMode, mariadbInstallDir)

        logger.Info(
            "Verifying MariaDB installation",
            "clientVersion", clientVersion,
            "path", binDir,
        )
        if isShowLogs {
            logger.Info(
                "Verifying MariaDB installation",
                "clientVersion", clientVersion,
                "path", binDir,
            )
        }

        if _, err := os.Stat(binDir); os.IsNotExist(err) {
            if envMode == env_utils.EnvModeDevelopment {

@@ -133,12 +136,14 @@ func VerifyMariadbInstallation(
            }
            cmdPath := GetMariadbExecutable(cmd, dummyServerVersion, envMode, mariadbInstallDir)

            logger.Info(
                "Checking for MariaDB command",
                "clientVersion", clientVersion,
                "command", cmd,
                "path", cmdPath,
            )
            if isShowLogs {
                logger.Info(
                    "Checking for MariaDB command",
                    "clientVersion", clientVersion,
                    "command", cmd,
                    "path", cmdPath,
                )
            }

            if _, err := os.Stat(cmdPath); os.IsNotExist(err) {
                if envMode == env_utils.EnvModeDevelopment {

@@ -162,11 +167,15 @@ func VerifyMariadbInstallation(
                continue
            }

            logger.Info("MariaDB command found", "clientVersion", clientVersion, "command", cmd)
            if isShowLogs {
                logger.Info("MariaDB command found", "clientVersion", clientVersion, "command", cmd)
            }
        }
    }

    logger.Info("MariaDB client tools verification completed!")
    if isShowLogs {
        logger.Info("MariaDB client tools verification completed!")
    }
}

// IsMariadbBackupVersionHigherThanRestoreVersion checks if backup was made with

@@ -53,13 +53,16 @@ func VerifyMongodbInstallation(
    logger *slog.Logger,
    envMode env_utils.EnvMode,
    mongodbInstallDir string,
    isShowLogs bool,
) {
    binDir := getMongodbBasePath(envMode, mongodbInstallDir)

    logger.Info(
        "Verifying MongoDB Database Tools installation",
        "path", binDir,
    )
    if isShowLogs {
        logger.Info(
            "Verifying MongoDB Database Tools installation",
            "path", binDir,
        )
    }

    if _, err := os.Stat(binDir); os.IsNotExist(err) {
        if envMode == env_utils.EnvModeDevelopment {

@@ -85,11 +88,13 @@ func VerifyMongodbInstallation(
    for _, cmd := range requiredCommands {
        cmdPath := GetMongodbExecutable(cmd, envMode, mongodbInstallDir)

        logger.Info(
            "Checking for MongoDB command",
            "command", cmd,
            "path", cmdPath,
        )
        if isShowLogs {
            logger.Info(
                "Checking for MongoDB command",
                "command", cmd,
                "path", cmdPath,
            )
        }

        if _, err := os.Stat(cmdPath); os.IsNotExist(err) {
            if envMode == env_utils.EnvModeDevelopment {

@@ -110,10 +115,14 @@ func VerifyMongodbInstallation(
            continue
        }

        logger.Info("MongoDB command found", "command", cmd)
        if isShowLogs {
            logger.Info("MongoDB command found", "command", cmd)
        }
    }

    logger.Info("MongoDB Database Tools verification completed!")
    if isShowLogs {
        logger.Info("MongoDB Database Tools verification completed!")
    }
}

// IsMongodbBackupVersionHigherThanRestoreVersion checks if backup was made with

@@ -55,6 +55,7 @@ func VerifyMysqlInstallation(
    logger *slog.Logger,
    envMode env_utils.EnvMode,
    mysqlInstallDir string,
    isShowLogs bool,
) {
    versions := []MysqlVersion{
        MysqlVersion57,

@@ -71,13 +72,15 @@ func VerifyMysqlInstallation(
    for _, version := range versions {
        binDir := getMysqlBasePath(version, envMode, mysqlInstallDir)

        logger.Info(
            "Verifying MySQL installation",
            "version",
            string(version),
            "path",
            binDir,
        )
        if isShowLogs {
            logger.Info(
                "Verifying MySQL installation",
                "version",
                string(version),
                "path",
                binDir,
            )
        }

        if _, err := os.Stat(binDir); os.IsNotExist(err) {
            if envMode == env_utils.EnvModeDevelopment {

@@ -108,15 +111,17 @@ func VerifyMysqlInstallation(
                mysqlInstallDir,
            )

            logger.Info(
                "Checking for MySQL command",
                "command",
                cmd,
                "version",
                string(version),
                "path",
                cmdPath,
            )
            if isShowLogs {
                logger.Info(
                    "Checking for MySQL command",
                    "command",
                    cmd,
                    "version",
                    string(version),
                    "path",
                    cmdPath,
                )
            }

            if _, err := os.Stat(cmdPath); os.IsNotExist(err) {
                if envMode == env_utils.EnvModeDevelopment {

@@ -143,25 +148,31 @@ func VerifyMysqlInstallation(
                continue
            }

            logger.Info(
                "MySQL command found",
                "command",
                cmd,
                "version",
                string(version),
            )
            if isShowLogs {
                logger.Info(
                    "MySQL command found",
                    "command",
                    cmd,
                    "version",
                    string(version),
                )
            }
        }

        logger.Info(
            "Installation of MySQL verified",
            "version",
            string(version),
            "path",
            binDir,
        )
        if isShowLogs {
            logger.Info(
                "Installation of MySQL verified",
                "version",
                string(version),
                "path",
                binDir,
            )
        }
    }

    logger.Info("MySQL version-specific client tools verification completed!")
    if isShowLogs {
        logger.Info("MySQL version-specific client tools verification completed!")
    }
}

// IsMysqlBackupVersionHigherThanRestoreVersion checks if backup was made with

@@ -40,6 +40,7 @@ func VerifyPostgresesInstallation(
    logger *slog.Logger,
    envMode env_utils.EnvMode,
    postgresesInstallDir string,
    isShowLogs bool,
) {
    versions := []PostgresqlVersion{
        PostgresqlVersion12,

@@ -59,13 +60,15 @@ func VerifyPostgresesInstallation(
    for _, version := range versions {
        binDir := getPostgresqlBasePath(version, envMode, postgresesInstallDir)

        logger.Info(
            "Verifying PostgreSQL installation",
            "version",
            string(version),
            "path",
            binDir,
        )
        if isShowLogs {
            logger.Info(
                "Verifying PostgreSQL installation",
                "version",
                string(version),
                "path",
                binDir,
            )
        }

        if _, err := os.Stat(binDir); os.IsNotExist(err) {
            if envMode == env_utils.EnvModeDevelopment {

@@ -96,15 +99,17 @@ func VerifyPostgresesInstallation(
                postgresesInstallDir,
            )

            logger.Info(
                "Checking for PostgreSQL command",
                "command",
                cmd,
                "version",
                string(version),
                "path",
                cmdPath,
            )
            if isShowLogs {
                logger.Info(
                    "Checking for PostgreSQL command",
                    "command",
                    cmd,
                    "version",
                    string(version),
                    "path",
                    cmdPath,
                )
            }

            if _, err := os.Stat(cmdPath); os.IsNotExist(err) {
                if envMode == env_utils.EnvModeDevelopment {

@@ -131,25 +136,33 @@ func VerifyPostgresesInstallation(
                os.Exit(1)
            }

            logger.Info(
                "PostgreSQL command found",
                "command",
                cmd,
                "version",
                string(version),
            )
            if isShowLogs {
                logger.Info(
                    "PostgreSQL command found",
                    "command",
                    cmd,
                    "version",
                    string(version),
                )
            }
        }

        logger.Info(
            "Installation of PostgreSQL verified",
            "version",
            string(version),
            "path",
            binDir,
        )
        if isShowLogs {
            logger.Info(
                "Installation of PostgreSQL verified",
                "version",
                string(version),
                "path",
                binDir,
            )
        }
    }

    logger.Info("All PostgreSQL version-specific client tools verification completed successfully!")
    if isShowLogs {
        logger.Info(
            "All PostgreSQL version-specific client tools verification completed successfully!",
        )
    }
}
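
// Sketch only (hypothetical call site): all four verifiers (MariaDB, MongoDB,
// MySQL, PostgreSQL) now take the same isShowLogs flag, so a caller can verify
// verbosely at startup and quietly elsewhere, e.g.:
//
//	VerifyPostgresesInstallation(logger, env_utils.EnvModeDevelopment, postgresesInstallDir, true)
//	VerifyMongodbInstallation(logger, env_utils.EnvModeDevelopment, mongodbInstallDir, false)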

// EscapePgpassField escapes special characters in a field value for .pgpass file format.

@@ -1,29 +0,0 @@
Write ReactComponent with the following structure:

interface Props {
    someValue: SomeValue;
}

const someHelperFunction = () => {
    ...
}

export const ReactComponent = ({ someValue }: Props): JSX.Element => {
    // first put states
    const [someState, setSomeState] = useState<...>(...)

    // then place functions
    const loadSomeData = async () => {
        ...
    }

    // then hooks
    useEffect(() => {
        loadSomeData();
    });

    // then calculated values
    const calculatedValue = someValue.calculate();

    return <div> ... </div>
}