Merge pull request #271 from databasus/develop

Develop
Rostislav Dugin
2026-01-15 21:19:55 +03:00
committed by GitHub
41 changed files with 4346 additions and 3070 deletions

View File

@@ -19,16 +19,6 @@ jobs:
with:
go-version: "1.24.9"
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Install golangci-lint
run: |
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.7.2
@@ -63,8 +53,6 @@ jobs:
uses: actions/setup-node@v4
with:
node-version: "20"
cache: "npm"
cache-dependency-path: frontend/package-lock.json
- name: Install dependencies
run: |
@@ -93,8 +81,6 @@ jobs:
uses: actions/setup-node@v4
with:
node-version: "20"
cache: "npm"
cache-dependency-path: frontend/package-lock.json
- name: Install dependencies
run: |
@@ -136,16 +122,6 @@ jobs:
with:
go-version: "1.24.9"
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Create .env file for testing
run: |
cd backend
@@ -321,34 +297,6 @@ jobs:
mkdir -p databasus-data/backups
mkdir -p databasus-data/temp
- name: Cache PostgreSQL client tools
id: cache-postgres
uses: actions/cache@v4
with:
path: /usr/lib/postgresql
key: postgres-clients-12-18-v1
- name: Cache MySQL client tools
id: cache-mysql
uses: actions/cache@v4
with:
path: backend/tools/mysql
key: mysql-clients-57-80-84-9-v1
- name: Cache MariaDB client tools
id: cache-mariadb
uses: actions/cache@v4
with:
path: backend/tools/mariadb
key: mariadb-clients-106-121-v1
- name: Cache MongoDB Database Tools
id: cache-mongodb
uses: actions/cache@v4
with:
path: backend/tools/mongodb
key: mongodb-database-tools-100.10.0-v1
- name: Install MySQL dependencies
run: |
sudo apt-get update -qq
@@ -356,31 +304,53 @@ jobs:
sudo ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5
sudo ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5
- name: Install PostgreSQL, MySQL, MariaDB and MongoDB client tools
if: steps.cache-postgres.outputs.cache-hit != 'true' || steps.cache-mysql.outputs.cache-hit != 'true' || steps.cache-mariadb.outputs.cache-hit != 'true' || steps.cache-mongodb.outputs.cache-hit != 'true'
run: |
chmod +x backend/tools/download_linux.sh
cd backend/tools
./download_linux.sh
- name: Setup PostgreSQL symlinks (when using cache)
if: steps.cache-postgres.outputs.cache-hit == 'true'
- name: Setup PostgreSQL, MySQL and MariaDB client tools from pre-built assets
run: |
cd backend/tools
mkdir -p postgresql
# Create directory structure
mkdir -p postgresql mysql mariadb mongodb/bin
# Copy PostgreSQL client tools (12-18) from pre-built assets
for version in 12 13 14 15 16 17 18; do
version_dir="postgresql/postgresql-$version"
mkdir -p "$version_dir/bin"
pg_bin_dir="/usr/lib/postgresql/$version/bin"
if [ -d "$pg_bin_dir" ]; then
ln -sf "$pg_bin_dir/pg_dump" "$version_dir/bin/pg_dump"
ln -sf "$pg_bin_dir/pg_dumpall" "$version_dir/bin/pg_dumpall"
ln -sf "$pg_bin_dir/psql" "$version_dir/bin/psql"
ln -sf "$pg_bin_dir/pg_restore" "$version_dir/bin/pg_restore"
ln -sf "$pg_bin_dir/createdb" "$version_dir/bin/createdb"
ln -sf "$pg_bin_dir/dropdb" "$version_dir/bin/dropdb"
fi
mkdir -p postgresql/postgresql-$version
cp -r ../../assets/tools/x64/postgresql/postgresql-$version/bin postgresql/postgresql-$version/
done
# Copy MySQL client tools (5.7, 8.0, 8.4, 9) from pre-built assets
for version in 5.7 8.0 8.4 9; do
mkdir -p mysql/mysql-$version
cp -r ../../assets/tools/x64/mysql/mysql-$version/bin mysql/mysql-$version/
done
# Copy MariaDB client tools (10.6, 12.1) from pre-built assets
for version in 10.6 12.1; do
mkdir -p mariadb/mariadb-$version
cp -r ../../assets/tools/x64/mariadb/mariadb-$version/bin mariadb/mariadb-$version/
done
echo "Pre-built client tools setup complete"
- name: Install MongoDB Database Tools
run: |
cd backend/tools
# MongoDB Database Tools must be downloaded (not in pre-built assets)
# They are backward compatible - single version supports all servers (4.0-8.0)
MONGODB_TOOLS_URL="https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-x86_64-100.10.0.deb"
echo "Downloading MongoDB Database Tools..."
wget -q "$MONGODB_TOOLS_URL" -O /tmp/mongodb-database-tools.deb
echo "Installing MongoDB Database Tools..."
sudo dpkg -i /tmp/mongodb-database-tools.deb || sudo apt-get install -f -y --no-install-recommends
# Create symlinks to tools directory
ln -sf /usr/bin/mongodump mongodb/bin/mongodump
ln -sf /usr/bin/mongorestore mongodb/bin/mongorestore
rm -f /tmp/mongodb-database-tools.deb
echo "MongoDB Database Tools installed successfully"
- name: Verify MariaDB client tools exist
run: |
@@ -726,4 +696,4 @@ jobs:
- name: Push Helm chart to GHCR
run: |
VERSION="${{ needs.determine-version.outputs.new_version }}"
helm push databasus-${VERSION}.tgz oci://ghcr.io/databasus/charts
helm push databasus-${VERSION}.tgz oci://ghcr.io/databasus/charts

AGENTS.md (new file, 1345 lines)
File diff suppressed because it is too large

View File

@@ -1,152 +0,0 @@
---
description:
globs:
alwaysApply: true
---
Always place private methods at the bottom of the file
**This rule applies to ALL Go files including tests, services, controllers, repositories, etc.**
In Go, exported (public) functions/methods start with uppercase letters, while unexported (private) ones start with lowercase letters.
## Structure Order:
1. Type definitions and constants
2. Public methods/functions (uppercase)
3. Private methods/functions (lowercase)
## Examples:
### Service with methods:
```go
type UserService struct {
repository *UserRepository
}
// Public methods first
func (s *UserService) CreateUser(user *User) error {
if err := s.validateUser(user); err != nil {
return err
}
return s.repository.Save(user)
}
func (s *UserService) GetUser(id uuid.UUID) (*User, error) {
return s.repository.FindByID(id)
}
// Private methods at the bottom
func (s *UserService) validateUser(user *User) error {
if user.Name == "" {
return errors.New("name is required")
}
return nil
}
```
### Package-level functions:
```go
package utils
// Public functions first
func ProcessData(data []byte) (Result, error) {
cleaned := sanitizeInput(data)
return parseData(cleaned)
}
func ValidateInput(input string) bool {
return isValidFormat(input) && checkLength(input)
}
// Private functions at the bottom
func sanitizeInput(data []byte) []byte {
// implementation
}
func parseData(data []byte) (Result, error) {
// implementation
}
func isValidFormat(input string) bool {
// implementation
}
func checkLength(input string) bool {
// implementation
}
```
### Test files:
```go
package user_test
// Public test functions first
func Test_CreateUser_ValidInput_UserCreated(t *testing.T) {
user := createTestUser()
result, err := service.CreateUser(user)
assert.NoError(t, err)
assert.NotNil(t, result)
}
func Test_GetUser_ExistingUser_ReturnsUser(t *testing.T) {
user := createTestUser()
// test implementation
}
// Private helper functions at the bottom
func createTestUser() *User {
return &User{
Name: "Test User",
Email: "test@example.com",
}
}
func setupTestDatabase() *Database {
// setup implementation
}
```
### Controller example:
```go
type ProjectController struct {
service *ProjectService
}
// Public HTTP handlers first
func (c *ProjectController) CreateProject(ctx *gin.Context) {
var request CreateProjectRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
c.handleError(ctx, err)
return
}
// handler logic
}
func (c *ProjectController) GetProject(ctx *gin.Context) {
projectID := c.extractProjectID(ctx)
// handler logic
}
// Private helper methods at the bottom
func (c *ProjectController) handleError(ctx *gin.Context, err error) {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}
func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
return uuid.MustParse(ctx.Param("projectId"))
}
```
## Key Points:
- **Exported/Public** = starts with uppercase letter (CreateUser, GetProject)
- **Unexported/Private** = starts with lowercase letter (validateUser, handleError)
- This improves code readability by showing the public API first
- Private helpers are implementation details, so they go at the bottom
- Apply this rule consistently across ALL Go files in the project

View File

@@ -1,45 +0,0 @@
---
description:
globs:
alwaysApply: true
---
## Comment Guidelines
1. **No obvious comments** - Don't state what the code already clearly shows
2. **Functions and variables should have meaningful names** - Code should be self-documenting
3. **Comments for unclear code only** - Only add comments when code logic isn't immediately clear
## Key Principles:
- **Code should tell a story** - Use descriptive variable and function names
- **Comments explain WHY, not WHAT** - The code shows what happens, comments explain business logic or complex decisions
- **Prefer refactoring over commenting** - If code needs explaining, consider making it clearer instead
- **API documentation is required** - Swagger comments for all HTTP endpoints are mandatory
- **Complex algorithms deserve comments** - Mathematical formulas, business rules, or non-obvious optimizations
Examples of useless comments:
1.
```sql
-- Create projects table
CREATE TABLE projects (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
name TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
```
2.
```go
// Create test project
project := CreateTestProject(projectName, user, router)
```
3.
```go
// CreateValidLogItems creates valid log items for testing
func CreateValidLogItems(count int, uniqueID string) []logs_receiving.LogItemRequestDTO {
```
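For contrast, a comment is worth keeping when it records the reason behind non-obvious logic. A minimal, hypothetical Go sketch (the package, constant, and function names are illustrative, not from this codebase):
```go
package backups

import "time"

// retentionGracePeriod keeps backups one extra day beyond the configured
// retention: cleanup runs once per day, so without this margin a backup
// created just before the cleanup tick could be deleted too early.
const retentionGracePeriod = 24 * time.Hour

// ExpirationTime returns the moment a backup becomes eligible for cleanup.
func ExpirationTime(createdAt time.Time, retention time.Duration) time.Time {
	return createdAt.Add(retention + retentionGracePeriod)
}
```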

View File

@@ -1,133 +0,0 @@
---
description:
globs:
alwaysApply: true
---
1. When we write a controller:
- we combine all routes into a single controller
- we name handler methods after what they do (.WhatWeDo), not as generic "handlers"
2. We use gin and *gin.Context for all routes.
Example:
func (c *TasksController) GetAvailableTasks(ctx *gin.Context) ...
3. We document all routes with Swagger in the following format:
package audit_logs
import (
"net/http"
user_models "databasus-backend/internal/features/users/models"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type AuditLogController struct {
auditLogService *AuditLogService
}
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
// All audit log endpoints require authentication (handled in main.go)
auditRoutes := router.Group("/audit-logs")
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
}
// GetGlobalAuditLogs
// @Summary Get global audit logs (ADMIN only)
// @Description Retrieve all audit logs across the system
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/global [get]
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
if err != nil {
if err.Error() == "only administrators can view global audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}
// GetUserAuditLogs
// @Summary Get user audit logs
// @Description Retrieve audit logs for a specific user
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param userId path string true "User ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/users/{userId} [get]
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
userIDStr := ctx.Param("userId")
targetUserID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
if err != nil {
if err.Error() == "insufficient permissions to view user audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}

View File

@@ -1,671 +0,0 @@
---
alwaysApply: false
---
This is example of CRUD:
------ backend/internal/features/audit_logs/controller.go ------
```
package audit_logs
import (
"net/http"
user_models "databasus-backend/internal/features/users/models"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type AuditLogController struct {
auditLogService *AuditLogService
}
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
// All audit log endpoints require authentication (handled in main.go)
auditRoutes := router.Group("/audit-logs")
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
}
// GetGlobalAuditLogs
// @Summary Get global audit logs (ADMIN only)
// @Description Retrieve all audit logs across the system
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/global [get]
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
if err != nil {
if err.Error() == "only administrators can view global audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}
// GetUserAuditLogs
// @Summary Get user audit logs
// @Description Retrieve audit logs for a specific user
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param userId path string true "User ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/users/{userId} [get]
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
userIDStr := ctx.Param("userId")
targetUserID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
if err != nil {
if err.Error() == "insufficient permissions to view user audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}
```
------ backend/internal/features/audit_logs/controller_test.go ------
```
package audit_logs
import (
"fmt"
"net/http"
"testing"
"time"
user_enums "databasus-backend/internal/features/users/enums"
users_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
users_testing "databasus-backend/internal/features/users/testing"
"databasus-backend/internal/storage"
test_utils "databasus-backend/internal/util/testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_GetGlobalAuditLogs_AdminSucceedsAndMemberGetsForbidden(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
memberUser := users_testing.CreateTestUser(user_enums.UserRoleMember)
router := createRouter()
service := GetAuditLogService()
projectID := uuid.New()
// Create test logs
createAuditLog(service, "Test log with user", &adminUser.UserID, nil)
createAuditLog(service, "Test log with project", nil, &projectID)
createAuditLog(service, "Test log standalone", nil, nil)
// Test ADMIN can access global logs
var response GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
"/api/v1/audit-logs/global?limit=10", "Bearer "+adminUser.Token, http.StatusOK, &response)
assert.GreaterOrEqual(t, len(response.AuditLogs), 3)
assert.GreaterOrEqual(t, response.Total, int64(3))
messages := extractMessages(response.AuditLogs)
assert.Contains(t, messages, "Test log with user")
assert.Contains(t, messages, "Test log with project")
assert.Contains(t, messages, "Test log standalone")
// Test MEMBER cannot access global logs
resp := test_utils.MakeGetRequest(t, router, "/api/v1/audit-logs/global",
"Bearer "+memberUser.Token, http.StatusForbidden)
assert.Contains(t, string(resp.Body), "only administrators can view global audit logs")
}
func Test_GetUserAuditLogs_PermissionsEnforcedCorrectly(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
router := createRouter()
service := GetAuditLogService()
projectID := uuid.New()
// Create test logs for different users
createAuditLog(service, "Test log user1 first", &user1.UserID, nil)
createAuditLog(service, "Test log user1 second", &user1.UserID, &projectID)
createAuditLog(service, "Test log user2 first", &user2.UserID, nil)
createAuditLog(service, "Test log user2 second", &user2.UserID, &projectID)
createAuditLog(service, "Test project log", nil, &projectID)
// Test ADMIN can view any user's logs
var user1Response GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=10", user1.UserID.String()),
"Bearer "+adminUser.Token, http.StatusOK, &user1Response)
assert.Equal(t, 2, len(user1Response.AuditLogs))
messages := extractMessages(user1Response.AuditLogs)
assert.Contains(t, messages, "Test log user1 first")
assert.Contains(t, messages, "Test log user1 second")
// Test user can view own logs
var ownLogsResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s", user2.UserID.String()),
"Bearer "+user2.Token, http.StatusOK, &ownLogsResponse)
assert.Equal(t, 2, len(ownLogsResponse.AuditLogs))
// Test user cannot view other user's logs
resp := test_utils.MakeGetRequest(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s", user1.UserID.String()),
"Bearer "+user2.Token, http.StatusForbidden)
assert.Contains(t, string(resp.Body), "insufficient permissions")
}
func Test_FilterAuditLogsByTime_ReturnsOnlyLogsBeforeDate(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
router := createRouter()
service := GetAuditLogService()
db := storage.GetDb()
baseTime := time.Now().UTC()
// Create logs with different timestamps
createTimedLog(db, &adminUser.UserID, "Test old log", baseTime.Add(-2*time.Hour))
createTimedLog(db, &adminUser.UserID, "Test recent log", baseTime.Add(-30*time.Minute))
createAuditLog(service, "Test current log", &adminUser.UserID, nil)
// Test filtering - get logs before 1 hour ago
beforeTime := baseTime.Add(-1 * time.Hour)
var filteredResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/global?beforeDate=%s", beforeTime.Format(time.RFC3339)),
"Bearer "+adminUser.Token, http.StatusOK, &filteredResponse)
// Verify only old log is returned
messages := extractMessages(filteredResponse.AuditLogs)
assert.Contains(t, messages, "Test old log")
assert.NotContains(t, messages, "Test recent log")
assert.NotContains(t, messages, "Test current log")
// Test without filter - should get all logs
var allResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router, "/api/v1/audit-logs/global",
"Bearer "+adminUser.Token, http.StatusOK, &allResponse)
assert.GreaterOrEqual(t, len(allResponse.AuditLogs), 3)
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
SetupDependencies()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
GetAuditLogController().RegisterRoutes(protected.(*gin.RouterGroup))
return router
}
```
------ backend/internal/features/audit_logs/di.go ------
```
package audit_logs
import (
users_services "databasus-backend/internal/features/users/services"
"databasus-backend/internal/util/logger"
)
var auditLogRepository = &AuditLogRepository{}
var auditLogService = &AuditLogService{
auditLogRepository: auditLogRepository,
logger: logger.GetLogger(),
}
var auditLogController = &AuditLogController{
auditLogService: auditLogService,
}
func GetAuditLogService() *AuditLogService {
return auditLogService
}
func GetAuditLogController() *AuditLogController {
return auditLogController
}
func SetupDependencies() {
users_services.GetUserService().SetAuditLogWriter(auditLogService)
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
}
```
------ backend/internal/features/audit_logs/dto.go ------
```
package audit_logs
import "time"
type GetAuditLogsRequest struct {
Limit int `form:"limit" json:"limit"`
Offset int `form:"offset" json:"offset"`
BeforeDate *time.Time `form:"beforeDate" json:"beforeDate"`
}
type GetAuditLogsResponse struct {
AuditLogs []*AuditLog `json:"auditLogs"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
```
------ backend/internal/features/audit_logs/models.go ------
```
package audit_logs
import (
"time"
"github.com/google/uuid"
)
type AuditLog struct {
ID uuid.UUID `json:"id" gorm:"column:id"`
UserID *uuid.UUID `json:"userId" gorm:"column:user_id"`
ProjectID *uuid.UUID `json:"projectId" gorm:"column:project_id"`
Message string `json:"message" gorm:"column:message"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
}
func (AuditLog) TableName() string {
return "audit_logs"
}
```
------ backend/internal/features/audit_logs/repository.go ------
```
package audit_logs
import (
"databasus-backend/internal/storage"
"time"
"github.com/google/uuid"
)
type AuditLogRepository struct{}
func (r *AuditLogRepository) Create(auditLog *AuditLog) error {
if auditLog.ID == uuid.Nil {
auditLog.ID = uuid.New()
}
return storage.GetDb().Create(auditLog).Error
}
func (r *AuditLogRepository) GetGlobal(limit, offset int, beforeDate *time.Time) ([]*AuditLog, error) {
var auditLogs []*AuditLog
query := storage.GetDb().Order("created_at DESC")
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.
Limit(limit).
Offset(offset).
Find(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) GetByUser(
userID uuid.UUID,
limit, offset int,
beforeDate *time.Time,
) ([]*AuditLog, error) {
var auditLogs []*AuditLog
query := storage.GetDb().
Where("user_id = ?", userID).
Order("created_at DESC")
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.
Limit(limit).
Offset(offset).
Find(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) GetByProject(
projectID uuid.UUID,
limit, offset int,
beforeDate *time.Time,
) ([]*AuditLog, error) {
var auditLogs []*AuditLog
query := storage.GetDb().
Where("project_id = ?", projectID).
Order("created_at DESC")
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.
Limit(limit).
Offset(offset).
Find(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
var count int64
query := storage.GetDb().Model(&AuditLog{})
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.Count(&count).Error
return count, err
}
```
------ backend/internal/features/audit_logs/service.go ------
```
package audit_logs
import (
"errors"
"log/slog"
"time"
user_enums "databasus-backend/internal/features/users/enums"
user_models "databasus-backend/internal/features/users/models"
"github.com/google/uuid"
)
type AuditLogService struct {
auditLogRepository *AuditLogRepository
logger *slog.Logger
}
func (s *AuditLogService) WriteAuditLog(
message string,
userID *uuid.UUID,
projectID *uuid.UUID,
) {
auditLog := &AuditLog{
UserID: userID,
ProjectID: projectID,
Message: message,
CreatedAt: time.Now().UTC(),
}
err := s.auditLogRepository.Create(auditLog)
if err != nil {
s.logger.Error("failed to create audit log", "error", err)
return
}
}
func (s *AuditLogService) CreateAuditLog(auditLog *AuditLog) error {
return s.auditLogRepository.Create(auditLog)
}
func (s *AuditLogService) GetGlobalAuditLogs(
user *user_models.User,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
if user.Role != user_enums.UserRoleAdmin {
return nil, errors.New("only administrators can view global audit logs")
}
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetGlobal(limit, offset, request.BeforeDate)
if err != nil {
return nil, err
}
total, err := s.auditLogRepository.CountGlobal(request.BeforeDate)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: total,
Limit: limit,
Offset: offset,
}, nil
}
func (s *AuditLogService) GetUserAuditLogs(
targetUserID uuid.UUID,
user *user_models.User,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
// Users can view their own logs, ADMIN can view any user's logs
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
return nil, errors.New("insufficient permissions to view user audit logs")
}
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetByUser(targetUserID, limit, offset, request.BeforeDate)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: int64(len(auditLogs)),
Limit: limit,
Offset: offset,
}, nil
}
func (s *AuditLogService) GetProjectAuditLogs(
projectID uuid.UUID,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetByProject(projectID, limit, offset, request.BeforeDate)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: int64(len(auditLogs)),
Limit: limit,
Offset: offset,
}, nil
}
```
------ backend/internal/features/audit_logs/service_test.go ------
```
package audit_logs
import (
"testing"
"time"
user_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
)
func Test_AuditLogs_ProjectSpecificLogs(t *testing.T) {
service := GetAuditLogService()
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
project1ID, project2ID := uuid.New(), uuid.New()
// Create test logs for projects
createAuditLog(service, "Test project1 log first", &user1.UserID, &project1ID)
createAuditLog(service, "Test project1 log second", &user2.UserID, &project1ID)
createAuditLog(service, "Test project2 log first", &user1.UserID, &project2ID)
createAuditLog(service, "Test project2 log second", &user2.UserID, &project2ID)
createAuditLog(service, "Test no project log", &user1.UserID, nil)
request := &GetAuditLogsRequest{Limit: 10, Offset: 0}
// Test project 1 logs
project1Response, err := service.GetProjectAuditLogs(project1ID, request)
assert.NoError(t, err)
assert.Equal(t, 2, len(project1Response.AuditLogs))
messages := extractMessages(project1Response.AuditLogs)
assert.Contains(t, messages, "Test project1 log first")
assert.Contains(t, messages, "Test project1 log second")
for _, log := range project1Response.AuditLogs {
assert.Equal(t, &project1ID, log.ProjectID)
}
// Test project 2 logs
project2Response, err := service.GetProjectAuditLogs(project2ID, request)
assert.NoError(t, err)
assert.Equal(t, 2, len(project2Response.AuditLogs))
messages2 := extractMessages(project2Response.AuditLogs)
assert.Contains(t, messages2, "Test project2 log first")
assert.Contains(t, messages2, "Test project2 log second")
// Test pagination
limitedResponse, err := service.GetProjectAuditLogs(project1ID,
&GetAuditLogsRequest{Limit: 1, Offset: 0})
assert.NoError(t, err)
assert.Equal(t, 1, len(limitedResponse.AuditLogs))
assert.Equal(t, 1, limitedResponse.Limit)
// Test beforeDate filter
beforeTime := time.Now().UTC().Add(-1 * time.Minute)
filteredResponse, err := service.GetProjectAuditLogs(project1ID,
&GetAuditLogsRequest{Limit: 10, BeforeDate: &beforeTime})
assert.NoError(t, err)
for _, log := range filteredResponse.AuditLogs {
assert.True(t, log.CreatedAt.Before(beforeTime))
}
}
func createAuditLog(service *AuditLogService, message string, userID, projectID *uuid.UUID) {
service.WriteAuditLog(message, userID, projectID)
}
func extractMessages(logs []*AuditLog) []string {
messages := make([]string, len(logs))
for i, log := range logs {
messages[i] = log.Message
}
return messages
}
func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt time.Time) {
log := &AuditLog{
ID: uuid.New(),
UserID: userID,
Message: message,
CreatedAt: createdAt,
}
db.Create(log)
}
```

View File

@@ -1,74 +0,0 @@
---
description:
globs:
alwaysApply: true
---
For DI files, use the implicit (positional) fields declaration style (especially
for controllers, services, repositories, use cases, etc., not for simple
data structures).
So, instead of:
var orderController = &OrderController{
orderService: orderService,
botUserService: bot_users.GetBotUserService(),
botService: bots.GetBotService(),
userService: users.GetUserService(),
}
Use:
var orderController = &OrderController{
orderService,
bot_users.GetBotUserService(),
bots.GetBotService(),
users.GetUserService(),
}
This is needed to avoid forgetting to update the DI declaration
when we add a new dependency.
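A minimal sketch of why the positional form helps (the struct and service names are illustrative, not from this codebase): a positional composite literal must list every field, so adding a new dependency to the struct breaks compilation until the DI file is updated, whereas the named form would silently leave the new field nil.
```go
package orders

type OrderService struct{}
type BotService struct{}

type OrderController struct {
	orderService *OrderService
	botService   *BotService
	// Adding a new field here (e.g. userService *UserService) makes the
	// positional literal below fail to compile until the value is provided.
}

// DI declaration using the implicit (positional) style.
var orderController = &OrderController{
	&OrderService{},
	&BotService{},
}
```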
---
Please enforce this style when a file looks like the following (it contains
service/controller/repository definitions and getters):
var orderBackgroundService = &OrderBackgroundService{
orderService: orderService,
orderPaymentRepository: orderPaymentRepository,
botService: bots.GetBotService(),
paymentSettingsService: payment_settings.GetPaymentSettingsService(),
orderSubscriptionListeners: []OrderSubscriptionListener{},
}
var orderController = &OrderController{
orderService: orderService,
botUserService: bot_users.GetBotUserService(),
botService: bots.GetBotService(),
userService: users.GetUserService(),
}
func GetUniquePaymentRepository() *repositories.UniquePaymentRepository {
return uniquePaymentRepository
}
func GetOrderPaymentRepository() *repositories.OrderPaymentRepository {
return orderPaymentRepository
}
func GetOrderService() *OrderService {
return orderService
}
func GetOrderController() *OrderController {
return orderController
}
func GetOrderBackgroundService() *OrderBackgroundService {
return orderBackgroundService
}
func GetOrderRepository() *repositories.OrderRepository {
return orderRepository
}

View File

@@ -1,27 +0,0 @@
---
description:
globs:
alwaysApply: true
---
When writing migrations:
- write them for PostgreSQL
- for PRIMARY UUID keys use gen_random_uuid()
- for time use TIMESTAMPTZ (timestamp with time zone)
- split table, constraint, and index declarations (table first, then the others one by one)
- format SQL prettily (add spaces, align column types), with constraints split across lines. Example:
CREATE TABLE marketplace_info (
bot_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
title TEXT NOT NULL,
description TEXT NOT NULL,
short_description TEXT NOT NULL,
tutorial_url TEXT,
info_order BIGINT NOT NULL DEFAULT 0,
is_published BOOLEAN NOT NULL DEFAULT FALSE
);
ALTER TABLE marketplace_info_images
ADD CONSTRAINT fk_marketplace_info_images_bot_id
FOREIGN KEY (bot_id)
REFERENCES marketplace_info (bot_id);

View File

@@ -1,12 +0,0 @@
---
description:
globs:
alwaysApply: true
---
When applying changes, do not forget to refactor old code.
You can shorten it, make it more readable, improve code quality, etc.
Common logic can be extracted to functions, constants, files, etc.
After each large change with more than ~50-100 lines of code, always run `make lint` (from the backend root folder) and, if you changed the frontend, run `npm run format` (from the frontend root folder).

View File

@@ -1,147 +0,0 @@
---
description:
globs:
alwaysApply: true
---
After writing tests, always launch them and verify that they pass.
## Test Naming Format
Use these naming patterns:
- `Test_WhatWeDo_WhatWeExpect`
- `Test_WhatWeDo_WhichConditions_WhatWeExpect`
## Examples from Real Codebase:
- `Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated`
- `Test_UpdateProject_WhenUserIsProjectAdmin_ProjectUpdated`
- `Test_DeleteApiKey_WhenUserIsProjectMember_ReturnsForbidden`
- `Test_GetProjectAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly`
- `Test_ProjectLifecycleE2E_CompletesSuccessfully`
## Testing Philosophy
**Prefer Controllers Over Unit Tests:**
- Test through HTTP endpoints via controllers whenever possible
- Avoid testing repositories, services in isolation - test via API instead
- Only use unit tests for complex model logic when no API exists
- Name test files `controller_test.go` or `service_test.go`, not `integration_test.go`
**Extract Common Logic to Testing Utilities:**
- Create `testing.go` or `testing/testing.go` files for shared test utilities
- Extract helpers for router creation, user setup, and model creation (through the API, not just struct construction)
- Reuse common patterns across different test files
**Refactor Existing Tests:**
- When working with existing tests, always look for opportunities to refactor and improve
- Extract repetitive setup code to common utilities
- Simplify complex tests by breaking them into smaller, focused tests
- Replace inline test data creation with reusable helper functions
- Consolidate similar test patterns across different test files
- Make tests more readable and maintainable for other developers
## Testing Utilities Structure
**Create `testing.go` or `testing/testing.go` files with common utilities:**
```go
package projects_testing
// CreateTestRouter creates unified router for all controllers
func CreateTestRouter(controllers ...ControllerInterface) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
for _, controller := range controllers {
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
controller.RegisterRoutes(routerGroup)
}
}
return router
}
// CreateTestProjectViaAPI creates project through HTTP API
func CreateTestProjectViaAPI(name string, owner *users_dto.SignInResponseDTO, router *gin.Engine) (*projects_models.Project, string) {
request := projects_dto.CreateProjectRequestDTO{Name: name}
w := MakeAPIRequest(router, "POST", "/api/v1/projects", "Bearer "+owner.Token, request)
// Handle response...
return project, owner.Token
}
// AddMemberToProject adds member via API call
func AddMemberToProject(project *projects_models.Project, member *users_dto.SignInResponseDTO, role users_enums.ProjectRole, ownerToken string, router *gin.Engine) {
// Implementation...
}
```
## Controller Test Examples
**Permission-based testing:**
```go
func Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated(t *testing.T) {
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
project, _ := projects_testing.CreateTestProjectViaAPI("Test Project", owner, router)
request := CreateApiKeyRequestDTO{Name: "Test API Key"}
var response ApiKey
test_utils.MakePostRequestAndUnmarshal(t, router, "/api/v1/projects/api-keys/"+project.ID.String(), "Bearer "+owner.Token, request, http.StatusOK, &response)
assert.Equal(t, "Test API Key", response.Name)
assert.NotEmpty(t, response.Token)
}
```
**Cross-project security testing:**
```go
func Test_UpdateApiKey_WithApiKeyFromDifferentProject_ReturnsBadRequest(t *testing.T) {
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
project1, _ := projects_testing.CreateTestProjectViaAPI("Project 1", owner1, router)
project2, _ := projects_testing.CreateTestProjectViaAPI("Project 2", owner2, router)
apiKey := CreateTestApiKey("Cross Project Key", project1.ID, owner1.Token, router)
// Try to update via different project endpoint
hackedName := "Hacked Key"
request := UpdateApiKeyRequestDTO{Name: &hackedName}
resp := test_utils.MakePutRequest(t, router, "/api/v1/projects/api-keys/"+project2.ID.String()+"/"+apiKey.ID.String(), "Bearer "+owner2.Token, request, http.StatusBadRequest)
assert.Contains(t, string(resp.Body), "API key does not belong to this project")
}
```
**E2E lifecycle testing:**
```go
func Test_ProjectLifecycleE2E_CompletesSuccessfully(t *testing.T) {
router := projects_testing.CreateTestRouter(GetProjectController(), GetMembershipController())
// 1. Create project
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
project := projects_testing.CreateTestProject("E2E Project", owner, router)
// 2. Add member
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
projects_testing.AddMemberToProject(project, member, users_enums.ProjectRoleMember, owner.Token, router)
// 3. Promote to admin
projects_testing.ChangeMemberRole(project, member.UserID, users_enums.ProjectRoleAdmin, owner.Token, router)
// 4. Transfer ownership
projects_testing.TransferProjectOwnership(project, member.UserID, owner.Token, router)
// 5. Verify new owner can manage project
finalProject := projects_testing.GetProject(project.ID, member.Token, router)
assert.Equal(t, project.ID, finalProject.ID)
}
```

View File

@@ -1,6 +0,0 @@
---
description:
globs:
alwaysApply: true
---
Always use time.Now().UTC() instead of time.Now()
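A small illustrative sketch of the rule in practice (the package and function names are hypothetical): normalizing to UTC at the point of creation keeps stored timestamps and comparisons independent of whichever node's local timezone produced them.
```go
package audit

import "time"

// NewTimestamp returns the creation time to persist on new records.
// Always normalized to UTC so values compare consistently across nodes
// regardless of their local timezone.
func NewTimestamp() time.Time {
	return time.Now().UTC() // per the rule: never bare time.Now()
}
```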

View File

@@ -2,8 +2,10 @@
DEV_DB_NAME=databasus
DEV_DB_USERNAME=postgres
DEV_DB_PASSWORD=Q1234567
#app
# app
ENV_MODE=development
# logging
SHOW_DB_INSTALLATION_VERIFICATION_LOGS=true
# db
DATABASE_DSN=host=dev-db user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable

View File

@@ -16,7 +16,6 @@ import (
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -28,6 +27,8 @@ import (
"databasus-backend/internal/features/restores"
"databasus-backend/internal/features/storages"
system_healthcheck "databasus-backend/internal/features/system/healthcheck"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
task_registry "databasus-backend/internal/features/tasks/registry"
users_controllers "databasus-backend/internal/features/users/controllers"
users_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
@@ -59,6 +60,8 @@ func main() {
cache_utils.TestCacheConnection()
if config.GetEnv().IsPrimaryNode {
log.Info("Clearing cache...")
err := cache_utils.ClearAllCache()
if err != nil {
log.Error("Failed to clear cache", "error", err)
@@ -239,7 +242,7 @@ func setUpDependencies() {
notifiers.SetupDependencies()
storages.SetupDependencies()
backups_config.SetupDependencies()
backups_cancellation.SetupDependencies()
task_cancellation.SetupDependencies()
}
func runBackgroundTasks(log *slog.Logger) {
@@ -257,14 +260,14 @@ func runBackgroundTasks(log *slog.Logger) {
cancel()
}()
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
if err != nil {
log.Error("Failed to clean temp folder", "error", err)
}
if config.GetEnv().IsPrimaryNode {
log.Info("Starting primary node background tasks...")
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
if err != nil {
log.Error("Failed to clean temp folder", "error", err)
}
go runWithPanicLogging(log, "backup background service", func() {
backuping.GetBackupsScheduler().Run(ctx)
})
@@ -284,6 +287,10 @@ func runBackgroundTasks(log *slog.Logger) {
go runWithPanicLogging(log, "download token cleanup background service", func() {
backups_download.GetDownloadTokenBackgroundService().Run(ctx)
})
go runWithPanicLogging(log, "task nodes registry background service", func() {
task_registry.GetTaskNodesRegistry().Run(ctx)
})
} else {
log.Info("Skipping primary node tasks as not primary node")
}

View File

@@ -30,6 +30,8 @@ type EnvVariables struct {
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`
ShowDbInstallationVerificationLogs bool `env:"SHOW_DB_INSTALLATION_VERIFICATION_LOGS"`
NodeID string
IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
@@ -169,6 +171,11 @@ func loadEnvVariables() {
os.Exit(1)
}
// Set default value for ShowDbInstallationVerificationLogs if not defined
if os.Getenv("SHOW_DB_INSTALLATION_VERIFICATION_LOGS") == "" {
env.ShowDbInstallationVerificationLogs = true
}
for _, arg := range os.Args {
if strings.Contains(arg, "test") {
env.IsTesting = true
@@ -192,16 +199,36 @@ func loadEnvVariables() {
log.Info("ENV_MODE loaded", "mode", env.EnvMode)
env.PostgresesInstallDir = filepath.Join(backendRoot, "tools", "postgresql")
tools.VerifyPostgresesInstallation(log, env.EnvMode, env.PostgresesInstallDir)
tools.VerifyPostgresesInstallation(
log,
env.EnvMode,
env.PostgresesInstallDir,
env.ShowDbInstallationVerificationLogs,
)
env.MysqlInstallDir = filepath.Join(backendRoot, "tools", "mysql")
tools.VerifyMysqlInstallation(log, env.EnvMode, env.MysqlInstallDir)
tools.VerifyMysqlInstallation(
log,
env.EnvMode,
env.MysqlInstallDir,
env.ShowDbInstallationVerificationLogs,
)
env.MariadbInstallDir = filepath.Join(backendRoot, "tools", "mariadb")
tools.VerifyMariadbInstallation(log, env.EnvMode, env.MariadbInstallDir)
tools.VerifyMariadbInstallation(
log,
env.EnvMode,
env.MariadbInstallDir,
env.ShowDbInstallationVerificationLogs,
)
env.MongodbInstallDir = filepath.Join(backendRoot, "tools", "mongodb")
tools.VerifyMongodbInstallation(log, env.EnvMode, env.MongodbInstallDir)
tools.VerifyMongodbInstallation(
log,
env.EnvMode,
env.MongodbInstallDir,
env.ShowDbInstallationVerificationLogs,
)
env.NodeID = uuid.New().String()
if env.NodeNetworkThroughputMBs == 0 {

View File

@@ -3,11 +3,12 @@ package backuping
import (
"context"
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
task_registry "databasus-backend/internal/features/tasks/registry"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
"errors"
@@ -20,6 +21,11 @@ import (
"github.com/google/uuid"
)
const (
heartbeatTickerInterval = 15 * time.Second
backuperHeathcheckThreshold = 5 * time.Minute
)
type BackuperNode struct {
databaseService *databases.DatabaseService
fieldEncryptor util_encryption.FieldEncryptor
@@ -28,8 +34,8 @@ type BackuperNode struct {
backupConfigService *backups_config.BackupConfigService
storageService *storages.StorageService
notificationSender backups_core.NotificationSender
backupCancelManager *backups_cancellation.BackupCancelManager
nodesRegistry *BackupNodesRegistry
backupCancelManager *tasks_cancellation.TaskCancelManager
tasksRegistry *task_registry.TaskNodesRegistry
logger *slog.Logger
createBackupUseCase backups_core.CreateBackupUsecase
nodeID uuid.UUID
@@ -42,20 +48,19 @@ func (n *BackuperNode) Run(ctx context.Context) {
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
backupNode := BackupNode{
backupNode := task_registry.TaskNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: time.Now().UTC(),
}
if err := n.nodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
if err := n.tasksRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
n.MakeBackup(backupID, isCallNotifier)
if err := n.nodesRegistry.PublishBackupCompletion(n.nodeID.String(), backupID); err != nil {
if err := n.tasksRegistry.PublishTaskCompletion(n.nodeID.String(), backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
@@ -66,17 +71,17 @@ func (n *BackuperNode) Run(ctx context.Context) {
}
}
if err := n.nodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID.String(), backupHandler); err != nil {
if err := n.tasksRegistry.SubscribeNodeForTasksAssignment(n.nodeID.String(), backupHandler); err != nil {
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
panic(err)
}
defer func() {
if err := n.nodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
if err := n.tasksRegistry.UnsubscribeNodeForTasksAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
}
}()
ticker := time.NewTicker(15 * time.Second)
ticker := time.NewTicker(heartbeatTickerInterval)
defer ticker.Stop()
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
@@ -86,7 +91,7 @@ func (n *BackuperNode) Run(ctx context.Context) {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.nodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
if err := n.tasksRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
}
@@ -98,7 +103,7 @@ func (n *BackuperNode) Run(ctx context.Context) {
}
func (n *BackuperNode) IsBackuperRunning() bool {
return n.lastHeartbeat.After(time.Now().UTC().Add(-5 * time.Minute))
return n.lastHeartbeat.After(time.Now().UTC().Add(-backuperHeathcheckThreshold))
}
func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
@@ -147,8 +152,8 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
}
ctx, cancel := context.WithCancel(context.Background())
n.backupCancelManager.RegisterBackup(backup.ID, cancel)
defer n.backupCancelManager.UnregisterBackup(backup.ID)
n.backupCancelManager.RegisterTask(backup.ID, cancel)
defer n.backupCancelManager.UnregisterTask(backup.ID)
backupMetadata, err := n.createBackupUseCase.Execute(
ctx,
@@ -335,10 +340,9 @@ func (n *BackuperNode) SendBackupNotification(
}
}
func (n *BackuperNode) sendHeartbeat(backupNode *BackupNode) {
func (n *BackuperNode) sendHeartbeat(backupNode *task_registry.TaskNode) {
n.lastHeartbeat = time.Now().UTC()
backupNode.LastHeartbeat = time.Now().UTC()
if err := n.nodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
if err := n.tasksRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
n.logger.Error("Failed to send heartbeat", "error", err)
}
}

View File

@@ -2,15 +2,15 @@ package backuping
import (
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
task_registry "databasus-backend/internal/features/tasks/registry"
workspaces_services "databasus-backend/internal/features/workspaces/services"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
"time"
@@ -20,15 +20,9 @@ import (
var backupRepository = &backups_core.BackupRepository{}
var backupCancelManager = backups_cancellation.GetBackupCancelManager()
var taskCancelManager = tasks_cancellation.GetTaskCancelManager()
var nodesRegistry = &BackupNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
}
var nodesRegistry = task_registry.GetTaskNodesRegistry()
func getNodeID() uuid.UUID {
nodeIDStr := config.GetEnv().NodeID
@@ -48,7 +42,7 @@ var backuperNode = &BackuperNode{
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
backupCancelManager,
taskCancelManager,
nodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
@@ -60,7 +54,7 @@ var backupsScheduler = &BackupsScheduler{
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
backupCancelManager,
taskCancelManager,
nodesRegistry,
time.Now().UTC(),
logger.GetLogger(),

View File

@@ -1,32 +1,6 @@
package backuping
import (
"time"
"github.com/google/uuid"
)
type BackupNode struct {
ID uuid.UUID `json:"id"`
ThroughputMBs int `json:"throughputMBs"`
LastHeartbeat time.Time `json:"lastHeartbeat"`
}
type BackupNodeStats struct {
ID uuid.UUID `json:"id"`
ActiveBackups int `json:"activeBackups"`
}
type BackupSubmitMessage struct {
NodeID string `json:"nodeId"`
BackupID string `json:"backupId"`
IsCallNotifier bool `json:"isCallNotifier"`
}
type BackupCompletionMessage struct {
NodeID string `json:"nodeId"`
BackupID string `json:"backupId"`
}
import "github.com/google/uuid"
type BackupToNodeRelation struct {
NodeID uuid.UUID `json:"nodeId"`

View File

@@ -1,448 +0,0 @@
package backuping
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"time"
cache_utils "databasus-backend/internal/util/cache"
"github.com/google/uuid"
"github.com/valkey-io/valkey-go"
)
const (
nodeInfoKeyPrefix = "node:"
nodeInfoKeySuffix = ":info"
nodeActiveBackupsPrefix = "node:"
nodeActiveBackupsSuffix = ":active_backups"
backupSubmitChannel = "backup:submit"
backupCompletionChannel = "backup:completion"
)
// BackupNodesRegistry helps to sync backups scheduler and backup nodes.
// Features:
// - Track node availability and load level
// - Assign from scheduler to node backups needed to be processed
// - Notify scheduler from node about backup completion
type BackupNodesRegistry struct {
client valkey.Client
logger *slog.Logger
timeout time.Duration
pubsubBackups *cache_utils.PubSubManager
pubsubCompletions *cache_utils.PubSubManager
}
func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
var allKeys []string
cursor := uint64(0)
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
for {
result := r.client.Do(
ctx,
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
)
if result.Error() != nil {
return nil, fmt.Errorf("failed to scan node keys: %w", result.Error())
}
scanResult, err := result.AsScanEntry()
if err != nil {
return nil, fmt.Errorf("failed to parse scan result: %w", err)
}
allKeys = append(allKeys, scanResult.Elements...)
cursor = scanResult.Cursor
if cursor == 0 {
break
}
}
if len(allKeys) == 0 {
return []BackupNode{}, nil
}
keyDataMap, err := r.pipelineGetKeys(allKeys)
if err != nil {
return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
}
var nodes []BackupNode
for key, data := range keyDataMap {
var node BackupNode
if err := json.Unmarshal(data, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
continue
}
nodes = append(nodes, node)
}
return nodes, nil
}
func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
var allKeys []string
cursor := uint64(0)
pattern := nodeActiveBackupsPrefix + "*" + nodeActiveBackupsSuffix
for {
result := r.client.Do(
ctx,
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build(),
)
if result.Error() != nil {
return nil, fmt.Errorf("failed to scan active backups keys: %w", result.Error())
}
scanResult, err := result.AsScanEntry()
if err != nil {
return nil, fmt.Errorf("failed to parse scan result: %w", err)
}
allKeys = append(allKeys, scanResult.Elements...)
cursor = scanResult.Cursor
if cursor == 0 {
break
}
}
if len(allKeys) == 0 {
return []BackupNodeStats{}, nil
}
keyDataMap, err := r.pipelineGetKeys(allKeys)
if err != nil {
return nil, fmt.Errorf("failed to pipeline get active backups keys: %w", err)
}
var stats []BackupNodeStats
for key, data := range keyDataMap {
nodeID := r.extractNodeIDFromKey(key, nodeActiveBackupsPrefix, nodeActiveBackupsSuffix)
count, err := r.parseIntFromBytes(data)
if err != nil {
r.logger.Warn("Failed to parse active backups count", "key", key, "error", err)
continue
}
stat := BackupNodeStats{
ID: nodeID,
ActiveBackups: int(count),
}
stats = append(stats, stat)
}
return stats, nil
}
func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID string) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID, nodeActiveBackupsSuffix)
result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())
if result.Error() != nil {
return fmt.Errorf(
"failed to increment backups in progress for node %s: %w",
nodeID,
result.Error(),
)
}
return nil
}
func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID string) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID, nodeActiveBackupsSuffix)
result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
if result.Error() != nil {
return fmt.Errorf(
"failed to decrement backups in progress for node %s: %w",
nodeID,
result.Error(),
)
}
newValue, err := result.AsInt64()
if err != nil {
return fmt.Errorf("failed to parse decremented value for node %s: %w", nodeID, err)
}
if newValue < 0 {
setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
setCancel()
r.logger.Warn("Active backups counter went below 0, reset to 0", "nodeID", nodeID)
}
return nil
}
func (r *BackupNodesRegistry) HearthbeatNodeInRegistry(now time.Time, backupNode BackupNode) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
backupNode.LastHeartbeat = now
data, err := json.Marshal(backupNode)
if err != nil {
return fmt.Errorf("failed to marshal backup node: %w", err)
}
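	// Node info key shape: nodeInfoKeyPrefix + nodeID + nodeInfoKeySuffix;
	// the value is the JSON-encoded BackupNode, including the refreshed LastHeartbeat.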
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
result := r.client.Do(
ctx,
r.client.B().Set().Key(key).Value(string(data)).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to register node %s: %w", backupNode.ID, result.Error())
}
return nil
}
func (r *BackupNodesRegistry) UnregisterNodeFromRegistry(backupNode BackupNode) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
counterKey := fmt.Sprintf(
"%s%s%s",
nodeActiveBackupsPrefix,
backupNode.ID.String(),
nodeActiveBackupsSuffix,
)
result := r.client.Do(
ctx,
r.client.B().Del().Key(infoKey, counterKey).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to unregister node %s: %w", backupNode.ID, result.Error())
}
r.logger.Info("Unregistered node from registry", "nodeID", backupNode.ID)
return nil
}
func (r *BackupNodesRegistry) AssignBackupToNode(
targetNodeID string,
backupID uuid.UUID,
isCallNotifier bool,
) error {
ctx := context.Background()
message := BackupSubmitMessage{
NodeID: targetNodeID,
BackupID: backupID.String(),
IsCallNotifier: isCallNotifier,
}
messageJSON, err := json.Marshal(message)
if err != nil {
return fmt.Errorf("failed to marshal backup submit message: %w", err)
}
err = r.pubsubBackups.Publish(ctx, backupSubmitChannel, string(messageJSON))
if err != nil {
return fmt.Errorf("failed to publish backup submit message: %w", err)
}
return nil
}
func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment(
nodeID string,
handler func(backupID uuid.UUID, isCallNotifier bool),
) error {
ctx := context.Background()
wrappedHandler := func(message string) {
var msg BackupSubmitMessage
if err := json.Unmarshal([]byte(message), &msg); err != nil {
r.logger.Warn("Failed to unmarshal backup submit message", "error", err)
return
}
if msg.NodeID != nodeID {
return
}
backupID, err := uuid.Parse(msg.BackupID)
if err != nil {
r.logger.Warn(
"Failed to parse backup ID from message",
"backupId",
msg.BackupID,
"error",
err,
)
return
}
handler(backupID, msg.IsCallNotifier)
}
err := r.pubsubBackups.Subscribe(ctx, backupSubmitChannel, wrappedHandler)
if err != nil {
return fmt.Errorf("failed to subscribe to backup submit channel: %w", err)
}
r.logger.Info("Subscribed to backup submit channel", "nodeID", nodeID)
return nil
}
func (r *BackupNodesRegistry) UnsubscribeNodeForBackupsAssignments() error {
err := r.pubsubBackups.Close()
if err != nil {
return fmt.Errorf("failed to unsubscribe from backup submit channel: %w", err)
}
r.logger.Info("Unsubscribed from backup submit channel")
return nil
}
func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID string, backupID uuid.UUID) error {
ctx := context.Background()
message := BackupCompletionMessage{
NodeID: nodeID,
BackupID: backupID.String(),
}
messageJSON, err := json.Marshal(message)
if err != nil {
return fmt.Errorf("failed to marshal backup completion message: %w", err)
}
err = r.pubsubCompletions.Publish(ctx, backupCompletionChannel, string(messageJSON))
if err != nil {
return fmt.Errorf("failed to publish backup completion message: %w", err)
}
return nil
}
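// Example of the node-side flow (an illustrative sketch; the real wiring lives
// in the backuper node, not in this file): a node subscribes for assignments
// addressed to it, runs the backup, then publishes a completion so the
// scheduler can decrement that node's active-backups counter.
//
//	_ = registry.SubscribeNodeForBackupsAssignment(nodeID, func(backupID uuid.UUID, isCallNotifier bool) {
//		// ...perform the backup here...
//		_ = registry.PublishBackupCompletion(nodeID, backupID)
//	})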
func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
handler func(nodeID string, backupID uuid.UUID),
) error {
ctx := context.Background()
wrappedHandler := func(message string) {
var msg BackupCompletionMessage
if err := json.Unmarshal([]byte(message), &msg); err != nil {
r.logger.Warn("Failed to unmarshal backup completion message", "error", err)
return
}
backupID, err := uuid.Parse(msg.BackupID)
if err != nil {
r.logger.Warn(
"Failed to parse backup ID from completion message",
"backupId",
msg.BackupID,
"error",
err,
)
return
}
handler(msg.NodeID, backupID)
}
err := r.pubsubCompletions.Subscribe(ctx, backupCompletionChannel, wrappedHandler)
if err != nil {
return fmt.Errorf("failed to subscribe to backup completion channel: %w", err)
}
r.logger.Info("Subscribed to backup completion channel")
return nil
}
func (r *BackupNodesRegistry) UnsubscribeForBackupsCompletions() error {
err := r.pubsubCompletions.Close()
if err != nil {
return fmt.Errorf("failed to unsubscribe from backup completion channel: %w", err)
}
r.logger.Info("Unsubscribed from backup completion channel")
return nil
}
func (r *BackupNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
nodeIDStr := strings.TrimPrefix(key, prefix)
nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix)
nodeID, err := uuid.Parse(nodeIDStr)
if err != nil {
r.logger.Warn("Failed to parse node ID from key", "key", key, "error", err)
return uuid.Nil
}
return nodeID
}
func (r *BackupNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
if len(keys) == 0 {
return make(map[string][]byte), nil
}
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
commands := make([]valkey.Completed, 0, len(keys))
for _, key := range keys {
commands = append(commands, r.client.B().Get().Key(key).Build())
}
results := r.client.DoMulti(ctx, commands...)
keyDataMap := make(map[string][]byte, len(keys))
for i, result := range results {
if result.Error() != nil {
r.logger.Warn("Failed to get key in pipeline", "key", keys[i], "error", result.Error())
continue
}
data, err := result.AsBytes()
if err != nil {
r.logger.Warn("Failed to parse key data in pipeline", "key", keys[i], "error", err)
continue
}
keyDataMap[keys[i]] = data
}
return keyDataMap, nil
}
func (r *BackupNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
str := string(data)
var count int64
_, err := fmt.Sscanf(str, "%d", &count)
if err != nil {
return 0, fmt.Errorf("failed to parse integer from bytes: %w", err)
}
return count, nil
}

View File

@@ -1,904 +0,0 @@
package backuping
import (
"context"
"testing"
"time"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_HearthbeatNodeInRegistry_RegistersNodeWithTTL(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
defer cleanupTestNode(registry, node)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 1)
assert.Equal(t, node.ID, nodes[0].ID)
assert.Equal(t, node.ThroughputMBs, nodes[0].ThroughputMBs)
}
func Test_UnregisterNodeFromRegistry_RemovesNodeAndCounter(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
assert.NoError(t, err)
err = registry.UnregisterNodeFromRegistry(node)
assert.NoError(t, err)
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.Empty(t, nodes)
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Empty(t, stats)
}
func Test_GetAvailableNodes_ReturnsAllRegisteredNodes(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node1 := createTestBackupNode()
node2 := createTestBackupNode()
node3 := createTestBackupNode()
defer cleanupTestNode(registry, node1)
defer cleanupTestNode(registry, node2)
defer cleanupTestNode(registry, node3)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
assert.NoError(t, err)
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 3)
nodeIDs := make(map[uuid.UUID]bool)
for _, node := range nodes {
nodeIDs[node.ID] = true
}
assert.True(t, nodeIDs[node1.ID])
assert.True(t, nodeIDs[node2.ID])
assert.True(t, nodeIDs[node3.ID])
}
func Test_GetAvailableNodes_WhenNoNodesExist_ReturnsEmptySlice(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.NotNil(t, nodes)
assert.Empty(t, nodes)
}
func Test_IncrementBackupsInProgress_IncrementsCounter(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
defer cleanupTestNode(registry, node)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Len(t, stats, 1)
assert.Equal(t, node.ID, stats[0].ID)
assert.Equal(t, 1, stats[0].ActiveBackups)
err = registry.IncrementBackupsInProgress(node.ID.String())
assert.NoError(t, err)
stats, err = registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Len(t, stats, 1)
assert.Equal(t, 2, stats[0].ActiveBackups)
}
func Test_DecrementBackupsInProgress_DecrementsCounter(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
defer cleanupTestNode(registry, node)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Equal(t, 3, stats[0].ActiveBackups)
err = registry.DecrementBackupsInProgress(node.ID.String())
assert.NoError(t, err)
stats, err = registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Equal(t, 2, stats[0].ActiveBackups)
err = registry.DecrementBackupsInProgress(node.ID.String())
assert.NoError(t, err)
stats, err = registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Equal(t, 1, stats[0].ActiveBackups)
}
func Test_DecrementBackupsInProgress_WhenNegative_ResetsToZero(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
defer cleanupTestNode(registry, node)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
err = registry.DecrementBackupsInProgress(node.ID.String())
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Len(t, stats, 1)
assert.Equal(t, 0, stats[0].ActiveBackups)
}
func Test_GetBackupNodesStats_ReturnsStatsForAllNodes(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node1 := createTestBackupNode()
node2 := createTestBackupNode()
node3 := createTestBackupNode()
defer cleanupTestNode(registry, node1)
defer cleanupTestNode(registry, node2)
defer cleanupTestNode(registry, node3)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node1.ID.String())
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node2.ID.String())
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node2.ID.String())
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node3.ID.String())
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node3.ID.String())
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node3.ID.String())
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Len(t, stats, 3)
statsMap := make(map[uuid.UUID]int)
for _, stat := range stats {
statsMap[stat.ID] = stat.ActiveBackups
}
assert.Equal(t, 1, statsMap[node1.ID])
assert.Equal(t, 2, statsMap[node2.ID])
assert.Equal(t, 3, statsMap[node3.ID])
}
func Test_GetBackupNodesStats_WhenNoStats_ReturnsEmptySlice(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.NotNil(t, stats)
assert.Empty(t, stats)
}
func Test_MultipleNodes_RegisteredAndQueriedCorrectly(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node1 := createTestBackupNode()
node1.ThroughputMBs = 50
node2 := createTestBackupNode()
node2.ThroughputMBs = 100
node3 := createTestBackupNode()
node3.ThroughputMBs = 150
defer cleanupTestNode(registry, node1)
defer cleanupTestNode(registry, node2)
defer cleanupTestNode(registry, node3)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
assert.NoError(t, err)
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 3)
nodeMap := make(map[uuid.UUID]BackupNode)
for _, node := range nodes {
nodeMap[node.ID] = node
}
assert.Equal(t, 50, nodeMap[node1.ID].ThroughputMBs)
assert.Equal(t, 100, nodeMap[node2.ID].ThroughputMBs)
assert.Equal(t, 150, nodeMap[node3.ID].ThroughputMBs)
}
func Test_BackupCounters_TrackedSeparatelyPerNode(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node1 := createTestBackupNode()
node2 := createTestBackupNode()
defer cleanupTestNode(registry, node1)
defer cleanupTestNode(registry, node2)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node1.ID.String())
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node1.ID.String())
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node2.ID.String())
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Len(t, stats, 2)
statsMap := make(map[uuid.UUID]int)
for _, stat := range stats {
statsMap[stat.ID] = stat.ActiveBackups
}
assert.Equal(t, 2, statsMap[node1.ID])
assert.Equal(t, 1, statsMap[node2.ID])
err = registry.DecrementBackupsInProgress(node1.ID.String())
assert.NoError(t, err)
stats, err = registry.GetBackupNodesStats()
assert.NoError(t, err)
statsMap = make(map[uuid.UUID]int)
for _, stat := range stats {
statsMap[stat.ID] = stat.ActiveBackups
}
assert.Equal(t, 1, statsMap[node1.ID])
assert.Equal(t, 1, statsMap[node2.ID])
}
func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
defer cleanupTestNode(registry, node)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
defer cancel()
invalidKey := nodeInfoKeyPrefix + uuid.New().String() + nodeInfoKeySuffix
registry.client.Do(
ctx,
registry.client.B().Set().Key(invalidKey).Value("invalid json data").Build(),
)
defer func() {
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), registry.timeout)
defer cleanupCancel()
registry.client.Do(cleanupCtx, registry.client.B().Del().Key(invalidKey).Build())
}()
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 1)
assert.Equal(t, node.ID, nodes[0].ID)
}
func Test_PipelineGetKeys_HandlesEmptyKeysList(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
keyDataMap, err := registry.pipelineGetKeys([]string{})
assert.NoError(t, err)
assert.NotNil(t, keyDataMap)
assert.Empty(t, keyDataMap)
}
func Test_HearthbeatNodeInRegistry_UpdatesLastHeartbeat(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
originalHeartbeat := node.LastHeartbeat
defer cleanupTestNode(registry, node)
time.Sleep(10 * time.Millisecond)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 1)
assert.True(t, nodes[0].LastHeartbeat.After(originalHeartbeat))
}
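// createTestRegistry constructs a registry with positional fields:
// client, logger, timeout, pubsubBackups, pubsubCompletions.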
func createTestRegistry() *BackupNodesRegistry {
return &BackupNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
}
}
func createTestBackupNode() BackupNode {
return BackupNode{
ID: uuid.New(),
ThroughputMBs: 100,
LastHeartbeat: time.Now().UTC(),
}
}
func cleanupTestNode(registry *BackupNodesRegistry, node BackupNode) {
registry.UnregisterNodeFromRegistry(node)
}
func Test_AssignBackupToNode_PublishesJsonMessageToChannel(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID := uuid.New()
err := registry.AssignBackupToNode(node.ID.String(), backupID, true)
assert.NoError(t, err)
}
func Test_SubscribeNodeForBackupsAssignment_ReceivesSubmittedBackupsForMatchingNode(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID := uuid.New()
defer registry.UnsubscribeNodeForBackupsAssignments()
receivedBackupID := make(chan uuid.UUID, 1)
handler := func(id uuid.UUID, isCallNotifier bool) {
receivedBackupID <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node.ID.String(), backupID, true)
assert.NoError(t, err)
select {
case received := <-receivedBackupID:
assert.Equal(t, backupID, received)
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for backup message")
}
}
func Test_SubscribeNodeForBackupsAssignment_FiltersOutBackupsForDifferentNode(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node1 := createTestBackupNode()
node2 := createTestBackupNode()
backupID := uuid.New()
defer registry.UnsubscribeNodeForBackupsAssignments()
receivedBackupID := make(chan uuid.UUID, 1)
handler := func(id uuid.UUID, isCallNotifier bool) {
receivedBackupID <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node2.ID.String(), backupID, false)
assert.NoError(t, err)
select {
case <-receivedBackupID:
t.Fatal("Should not receive backup for different node")
case <-time.After(500 * time.Millisecond):
}
}
func Test_SubscribeNodeForBackupsAssignment_ParsesJsonAndBackupIdCorrectly(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID1 := uuid.New()
backupID2 := uuid.New()
defer registry.UnsubscribeNodeForBackupsAssignments()
receivedBackups := make(chan uuid.UUID, 2)
handler := func(id uuid.UUID, isCallNotifier bool) {
receivedBackups <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
assert.NoError(t, err)
err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
assert.NoError(t, err)
received1 := <-receivedBackups
received2 := <-receivedBackups
receivedIDs := []uuid.UUID{received1, received2}
assert.Contains(t, receivedIDs, backupID1)
assert.Contains(t, receivedIDs, backupID2)
}
func Test_SubscribeNodeForBackupsAssignment_HandlesInvalidJson(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
defer registry.UnsubscribeNodeForBackupsAssignments()
receivedBackupID := make(chan uuid.UUID, 1)
handler := func(id uuid.UUID, isCallNotifier bool) {
receivedBackupID <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
ctx := context.Background()
err = registry.pubsubBackups.Publish(ctx, "backup:submit", "invalid json")
assert.NoError(t, err)
select {
case <-receivedBackupID:
t.Fatal("Should not receive backup for invalid JSON")
case <-time.After(500 * time.Millisecond):
}
}
func Test_UnsubscribeNodeForBackupsAssignments_StopsReceivingMessages(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID1 := uuid.New()
backupID2 := uuid.New()
receivedBackupID := make(chan uuid.UUID, 2)
handler := func(id uuid.UUID, isCallNotifier bool) {
receivedBackupID <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
assert.NoError(t, err)
received := <-receivedBackupID
assert.Equal(t, backupID1, received)
err = registry.UnsubscribeNodeForBackupsAssignments()
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
assert.NoError(t, err)
select {
case <-receivedBackupID:
t.Fatal("Should not receive backup after unsubscribe")
case <-time.After(500 * time.Millisecond):
}
}
func Test_SubscribeNodeForBackupsAssignment_WhenAlreadySubscribed_ReturnsError(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
defer registry.UnsubscribeNodeForBackupsAssignments()
handler := func(id uuid.UUID, isCallNotifier bool) {}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
assert.NoError(t, err)
err = registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
assert.Error(t, err)
assert.Contains(t, err.Error(), "already subscribed")
}
func Test_MultipleNodes_EachReceivesOnlyTheirBackups(t *testing.T) {
cache_utils.ClearAllCache()
registry1 := createTestRegistry()
registry2 := createTestRegistry()
registry3 := createTestRegistry()
node1 := createTestBackupNode()
node2 := createTestBackupNode()
node3 := createTestBackupNode()
backupID1 := uuid.New()
backupID2 := uuid.New()
backupID3 := uuid.New()
defer registry1.UnsubscribeNodeForBackupsAssignments()
defer registry2.UnsubscribeNodeForBackupsAssignments()
defer registry3.UnsubscribeNodeForBackupsAssignments()
receivedBackups1 := make(chan uuid.UUID, 3)
receivedBackups2 := make(chan uuid.UUID, 3)
receivedBackups3 := make(chan uuid.UUID, 3)
handler1 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups1 <- id }
handler2 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups2 <- id }
handler3 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups3 <- id }
err := registry1.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler1)
assert.NoError(t, err)
err = registry2.SubscribeNodeForBackupsAssignment(node2.ID.String(), handler2)
assert.NoError(t, err)
err = registry3.SubscribeNodeForBackupsAssignment(node3.ID.String(), handler3)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
submitRegistry := createTestRegistry()
err = submitRegistry.AssignBackupToNode(node1.ID.String(), backupID1, true)
assert.NoError(t, err)
err = submitRegistry.AssignBackupToNode(node2.ID.String(), backupID2, false)
assert.NoError(t, err)
err = submitRegistry.AssignBackupToNode(node3.ID.String(), backupID3, true)
assert.NoError(t, err)
select {
case received := <-receivedBackups1:
assert.Equal(t, backupID1, received)
case <-time.After(2 * time.Second):
t.Fatal("Node 1 timeout waiting for backup message")
}
select {
case received := <-receivedBackups2:
assert.Equal(t, backupID2, received)
case <-time.After(2 * time.Second):
t.Fatal("Node 2 timeout waiting for backup message")
}
select {
case received := <-receivedBackups3:
assert.Equal(t, backupID3, received)
case <-time.After(2 * time.Second):
t.Fatal("Node 3 timeout waiting for backup message")
}
select {
case <-receivedBackups1:
t.Fatal("Node 1 should not receive additional backups")
case <-time.After(300 * time.Millisecond):
}
select {
case <-receivedBackups2:
t.Fatal("Node 2 should not receive additional backups")
case <-time.After(300 * time.Millisecond):
}
select {
case <-receivedBackups3:
t.Fatal("Node 3 should not receive additional backups")
case <-time.After(300 * time.Millisecond):
}
}
func Test_PublishBackupCompletion_PublishesMessageToChannel(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID := uuid.New()
err := registry.PublishBackupCompletion(node.ID.String(), backupID)
assert.NoError(t, err)
}
func Test_SubscribeForBackupsCompletions_ReceivesCompletedBackups(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID := uuid.New()
defer registry.UnsubscribeForBackupsCompletions()
receivedBackupID := make(chan uuid.UUID, 1)
receivedNodeID := make(chan string, 1)
handler := func(nodeID string, backupID uuid.UUID) {
receivedNodeID <- nodeID
receivedBackupID <- backupID
}
err := registry.SubscribeForBackupsCompletions(handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.PublishBackupCompletion(node.ID.String(), backupID)
assert.NoError(t, err)
select {
case receivedNode := <-receivedNodeID:
assert.Equal(t, node.ID.String(), receivedNode)
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for node ID")
}
select {
case received := <-receivedBackupID:
assert.Equal(t, backupID, received)
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for backup completion message")
}
}
func Test_SubscribeForBackupsCompletions_ParsesJsonCorrectly(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID1 := uuid.New()
backupID2 := uuid.New()
defer registry.UnsubscribeForBackupsCompletions()
receivedBackups := make(chan uuid.UUID, 2)
handler := func(nodeID string, backupID uuid.UUID) {
receivedBackups <- backupID
}
err := registry.SubscribeForBackupsCompletions(handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
assert.NoError(t, err)
err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
assert.NoError(t, err)
received1 := <-receivedBackups
received2 := <-receivedBackups
receivedIDs := []uuid.UUID{received1, received2}
assert.Contains(t, receivedIDs, backupID1)
assert.Contains(t, receivedIDs, backupID2)
}
func Test_SubscribeForBackupsCompletions_HandlesInvalidJson(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
defer registry.UnsubscribeForBackupsCompletions()
receivedBackupID := make(chan uuid.UUID, 1)
handler := func(nodeID string, backupID uuid.UUID) {
receivedBackupID <- backupID
}
err := registry.SubscribeForBackupsCompletions(handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
ctx := context.Background()
err = registry.pubsubCompletions.Publish(ctx, "backup:completion", "invalid json")
assert.NoError(t, err)
select {
case <-receivedBackupID:
t.Fatal("Should not receive backup for invalid JSON")
case <-time.After(500 * time.Millisecond):
}
}
func Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID1 := uuid.New()
backupID2 := uuid.New()
receivedBackupID := make(chan uuid.UUID, 2)
handler := func(nodeID string, backupID uuid.UUID) {
receivedBackupID <- backupID
}
err := registry.SubscribeForBackupsCompletions(handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
assert.NoError(t, err)
received := <-receivedBackupID
assert.Equal(t, backupID1, received)
err = registry.UnsubscribeForBackupsCompletions()
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
assert.NoError(t, err)
select {
case <-receivedBackupID:
t.Fatal("Should not receive backup after unsubscribe")
case <-time.After(500 * time.Millisecond):
}
}
func Test_SubscribeForBackupsCompletions_WhenAlreadySubscribed_ReturnsError(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
defer registry.UnsubscribeForBackupsCompletions()
handler := func(nodeID string, backupID uuid.UUID) {}
err := registry.SubscribeForBackupsCompletions(handler)
assert.NoError(t, err)
err = registry.SubscribeForBackupsCompletions(handler)
assert.Error(t, err)
assert.Contains(t, err.Error(), "already subscribed")
}
func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
cache_utils.ClearAllCache()
registry1 := createTestRegistry()
registry2 := createTestRegistry()
registry3 := createTestRegistry()
node1 := createTestBackupNode()
node2 := createTestBackupNode()
node3 := createTestBackupNode()
backupID1 := uuid.New()
backupID2 := uuid.New()
backupID3 := uuid.New()
defer registry1.UnsubscribeForBackupsCompletions()
defer registry2.UnsubscribeForBackupsCompletions()
defer registry3.UnsubscribeForBackupsCompletions()
receivedBackups1 := make(chan uuid.UUID, 3)
receivedBackups2 := make(chan uuid.UUID, 3)
receivedBackups3 := make(chan uuid.UUID, 3)
handler1 := func(nodeID string, backupID uuid.UUID) { receivedBackups1 <- backupID }
handler2 := func(nodeID string, backupID uuid.UUID) { receivedBackups2 <- backupID }
handler3 := func(nodeID string, backupID uuid.UUID) { receivedBackups3 <- backupID }
err := registry1.SubscribeForBackupsCompletions(handler1)
assert.NoError(t, err)
err = registry2.SubscribeForBackupsCompletions(handler2)
assert.NoError(t, err)
err = registry3.SubscribeForBackupsCompletions(handler3)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
publishRegistry := createTestRegistry()
err = publishRegistry.PublishBackupCompletion(node1.ID.String(), backupID1)
assert.NoError(t, err)
err = publishRegistry.PublishBackupCompletion(node2.ID.String(), backupID2)
assert.NoError(t, err)
err = publishRegistry.PublishBackupCompletion(node3.ID.String(), backupID3)
assert.NoError(t, err)
receivedAll1 := []uuid.UUID{}
receivedAll2 := []uuid.UUID{}
receivedAll3 := []uuid.UUID{}
for i := 0; i < 3; i++ {
select {
case received := <-receivedBackups1:
receivedAll1 = append(receivedAll1, received)
case <-time.After(2 * time.Second):
t.Fatal("Subscriber 1 timeout waiting for completion message")
}
}
for i := 0; i < 3; i++ {
select {
case received := <-receivedBackups2:
receivedAll2 = append(receivedAll2, received)
case <-time.After(2 * time.Second):
t.Fatal("Subscriber 2 timeout waiting for completion message")
}
}
for i := 0; i < 3; i++ {
select {
case received := <-receivedBackups3:
receivedAll3 = append(receivedAll3, received)
case <-time.After(2 * time.Second):
t.Fatal("Subscriber 3 timeout waiting for completion message")
}
}
assert.Contains(t, receivedAll1, backupID1)
assert.Contains(t, receivedAll1, backupID2)
assert.Contains(t, receivedAll1, backupID3)
assert.Contains(t, receivedAll2, backupID1)
assert.Contains(t, receivedAll2, backupID2)
assert.Contains(t, receivedAll2, backupID3)
assert.Contains(t, receivedAll3, backupID1)
assert.Contains(t, receivedAll3, backupID2)
assert.Contains(t, receivedAll3, backupID3)
}

View File

@@ -3,10 +3,11 @@ package backuping
import (
"context"
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
task_registry "databasus-backend/internal/features/tasks/registry"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
"fmt"
@@ -16,12 +17,18 @@ import (
"github.com/google/uuid"
)
const (
schedulerStartupDelay = 1 * time.Minute
schedulerTickerInterval = 1 * time.Minute
schedulerHealthcheckThreshold = 5 * time.Minute
)
type BackupsScheduler struct {
backupRepository *backups_core.BackupRepository
backupConfigService *backups_config.BackupConfigService
storageService *storages.StorageService
backupCancelManager *backups_cancellation.BackupCancelManager
nodesRegistry *BackupNodesRegistry
taskCancelManager *task_cancellation.TaskCancelManager
tasksRegistry *task_registry.TaskNodesRegistry
lastBackupTime time.Time
logger *slog.Logger
@@ -35,7 +42,7 @@ func (s *BackupsScheduler) Run(ctx context.Context) {
if config.GetEnv().IsManyNodesMode {
// wait for other nodes to start
time.Sleep(1 * time.Minute)
time.Sleep(schedulerStartupDelay)
}
if err := s.failBackupsInProgress(); err != nil {
@@ -43,12 +50,12 @@ func (s *BackupsScheduler) Run(ctx context.Context) {
panic(err)
}
if err := s.nodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted); err != nil {
if err := s.tasksRegistry.SubscribeForTasksCompletions(s.onBackupCompleted); err != nil {
s.logger.Error("Failed to subscribe to backup completions", "error", err)
panic(err)
}
defer func() {
if err := s.nodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
if err := s.tasksRegistry.UnsubscribeForTasksCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
}
}()
@@ -57,7 +64,7 @@ func (s *BackupsScheduler) Run(ctx context.Context) {
return
}
ticker := time.NewTicker(1 * time.Minute)
ticker := time.NewTicker(schedulerTickerInterval)
defer ticker.Stop()
for {
@@ -84,7 +91,7 @@ func (s *BackupsScheduler) Run(ctx context.Context) {
func (s *BackupsScheduler) IsSchedulerRunning() bool {
// if last backup time is more than 5 minutes ago, return false
return s.lastBackupTime.After(time.Now().UTC().Add(-5 * time.Minute))
return s.lastBackupTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
}
func (s *BackupsScheduler) failBackupsInProgress() error {
@@ -93,12 +100,10 @@ func (s *BackupsScheduler) failBackupsInProgress() error {
return err
}
fmt.Println("Backups in progress", len(backupsInProgress))
for _, backup := range backupsInProgress {
if err := s.backupCancelManager.CancelBackup(backup.ID); err != nil {
if err := s.taskCancelManager.CancelTask(backup.ID); err != nil {
s.logger.Error(
"Failed to cancel backup via context manager",
"Failed to cancel backup via task cancel manager",
"backupId",
backup.ID,
"error",
@@ -175,7 +180,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
return
}
if err := s.nodesRegistry.IncrementBackupsInProgress(leastBusyNodeID.String()); err != nil {
if err := s.tasksRegistry.IncrementTasksInProgress(leastBusyNodeID.String()); err != nil {
s.logger.Error(
"Failed to increment backups in progress",
"nodeId",
@@ -188,7 +193,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
return
}
if err := s.nodesRegistry.AssignBackupToNode(leastBusyNodeID.String(), backup.ID, isCallNotifier); err != nil {
if err := s.tasksRegistry.AssignTaskToNode(leastBusyNodeID.String(), backup.ID, isCallNotifier); err != nil {
s.logger.Error(
"Failed to submit backup",
"nodeId",
@@ -198,7 +203,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
"error",
err,
)
if decrementErr := s.nodesRegistry.DecrementBackupsInProgress(leastBusyNodeID.String()); decrementErr != nil {
if decrementErr := s.tasksRegistry.DecrementTasksInProgress(leastBusyNodeID.String()); decrementErr != nil {
s.logger.Error(
"Failed to decrement backups in progress after submit failure",
"nodeId",
@@ -393,7 +398,7 @@ func (s *BackupsScheduler) runPendingBackups() error {
}
func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
nodes, err := s.nodesRegistry.GetAvailableNodes()
nodes, err := s.tasksRegistry.GetAvailableNodes()
if err != nil {
return nil, fmt.Errorf("failed to get available nodes: %w", err)
}
@@ -402,27 +407,22 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
return nil, fmt.Errorf("no nodes available")
}
stats, err := s.nodesRegistry.GetBackupNodesStats()
stats, err := s.tasksRegistry.GetNodesStats()
if err != nil {
return nil, fmt.Errorf("failed to get backup nodes stats: %w", err)
}
statsMap := make(map[uuid.UUID]int)
for _, stat := range stats {
statsMap[stat.ID] = stat.ActiveBackups
statsMap[stat.ID] = stat.ActiveTasks
}
var bestNode *BackupNode
var bestNode *task_registry.TaskNode
var bestScore float64 = -1
now := time.Now().UTC()
for i := range nodes {
node := &nodes[i]
if now.Sub(node.LastHeartbeat) > 2*time.Minute {
continue
}
activeBackups := statsMap[node.ID]
var score float64
@@ -458,6 +458,13 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI
return
}
// Verify this task is actually a backup (registry contains multiple task types)
_, err = s.backupRepository.FindByID(backupID)
if err != nil {
// Not a backup task, ignore it
return
}
relation, exists := s.backupToNodeRelations[nodeID]
if !exists {
s.logger.Warn(
@@ -498,7 +505,7 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI
s.backupToNodeRelations[nodeID] = relation
}
if err := s.nodesRegistry.DecrementBackupsInProgress(nodeIDStr); err != nil {
if err := s.tasksRegistry.DecrementTasksInProgress(nodeIDStr); err != nil {
s.logger.Error(
"Failed to decrement backups in progress",
"nodeId",
@@ -512,18 +519,14 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI
}
func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
nodes, err := s.nodesRegistry.GetAvailableNodes()
nodes, err := s.tasksRegistry.GetAvailableNodes()
if err != nil {
return fmt.Errorf("failed to get available nodes: %w", err)
}
aliveNodeIDs := make(map[uuid.UUID]bool)
now := time.Now().UTC()
for _, node := range nodes {
if now.Sub(node.LastHeartbeat) <= 2*time.Minute {
aliveNodeIDs[node.ID] = true
}
aliveNodeIDs[node.ID] = true
}
for nodeID, relation := range s.backupToNodeRelations {
@@ -572,7 +575,7 @@ func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
continue
}
if err := s.nodesRegistry.DecrementBackupsInProgress(nodeID.String()); err != nil {
if err := s.tasksRegistry.DecrementTasksInProgress(nodeID.String()); err != nil {
s.logger.Error(
"Failed to decrement backups in progress for dead node",
"nodeId",

View File

@@ -7,6 +7,7 @@ import (
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
task_registry "databasus-backend/internal/features/tasks/registry"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
@@ -42,8 +43,8 @@ func Test_RunPendingBackups_WhenLastBackupWasYesterday_CreatesNewBackup(t *testi
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
@@ -111,8 +112,8 @@ func Test_RunPendingBackups_WhenLastBackupWasRecentlyCompleted_SkipsBackup(t *te
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
@@ -179,8 +180,8 @@ func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesDisabled_SkipsBackup(t
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
@@ -251,8 +252,8 @@ func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesEnabled_CreatesNewBack
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
@@ -324,8 +325,8 @@ func Test_RunPendingBackups_WhenFailedBackupsExceedMaxRetries_SkipsBackup(t *tes
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
@@ -396,8 +397,8 @@ func Test_RunPendingBackups_WhenBackupsDisabled_SkipsBackup(t *testing.T) {
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
@@ -449,6 +450,8 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
var mockNodeID uuid.UUID
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
@@ -457,9 +460,15 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
// Clean up mock node
if mockNodeID != uuid.Nil {
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: mockNodeID})
}
cache_utils.ClearAllCache()
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
@@ -479,7 +488,7 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
assert.NoError(t, err)
// Register mock node without subscribing to backups (simulates node crash after registration)
mockNodeID := uuid.New()
mockNodeID = uuid.New()
err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
assert.NoError(t, err)
@@ -493,12 +502,12 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
// Verify Valkey counter was incremented when backup was assigned
stats, err := nodesRegistry.GetBackupNodesStats()
stats, err := nodesRegistry.GetNodesStats()
assert.NoError(t, err)
foundStat := false
for _, stat := range stats {
if stat.ID == mockNodeID {
assert.Equal(t, 1, stat.ActiveBackups)
assert.Equal(t, 1, stat.ActiveTasks)
foundStat = true
break
}
@@ -523,19 +532,117 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
assert.Contains(t, *backups[0].FailMessage, "node unavailability")
// Verify Valkey counter was decremented after backup failed
stats, err = nodesRegistry.GetBackupNodesStats()
stats, err = nodesRegistry.GetNodesStats()
assert.NoError(t, err)
for _, stat := range stats {
if stat.ID == mockNodeID {
assert.Equal(t, 0, stat.ActiveBackups)
assert.Equal(t, 0, stat.ActiveTasks)
}
}
// Node info should still exist in registry (not removed by checkDeadNodesAndFailBackups)
node, err := GetNodeFromRegistry(mockNodeID)
time.Sleep(200 * time.Millisecond)
}
func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
var mockNodeID uuid.UUID
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
// Clean up mock node
if mockNodeID != uuid.Nil {
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: mockNodeID})
}
cache_utils.ClearAllCache()
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
assert.NotNil(t, node)
assert.Equal(t, mockNodeID, node.ID)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Register mock node
mockNodeID = uuid.New()
err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
assert.NoError(t, err)
// Start a backup and assign it to the node
GetBackupsScheduler().StartBackup(database.ID, false)
time.Sleep(100 * time.Millisecond)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
// Get initial state of the registry
initialStats, err := nodesRegistry.GetNodesStats()
assert.NoError(t, err)
var initialActiveTasks int
for _, stat := range initialStats {
if stat.ID == mockNodeID {
initialActiveTasks = stat.ActiveTasks
break
}
}
assert.Equal(t, 1, initialActiveTasks, "Should have 1 active task")
// Call onBackupCompleted with a random UUID (not a backup ID)
nonBackupTaskID := uuid.New()
GetBackupsScheduler().onBackupCompleted(mockNodeID.String(), nonBackupTaskID)
time.Sleep(100 * time.Millisecond)
// Verify: Active tasks counter should remain the same (not decremented)
stats, err := nodesRegistry.GetNodesStats()
assert.NoError(t, err)
for _, stat := range stats {
if stat.ID == mockNodeID {
assert.Equal(t, initialActiveTasks, stat.ActiveTasks,
"Active tasks should not change for non-backup task")
}
}
// Verify: backup should still be in progress (not modified)
backups, err = backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status,
"Backup status should not change for non-backup task completion")
// Verify: backupToNodeRelations should still contain the node
scheduler := GetBackupsScheduler()
_, exists := scheduler.backupToNodeRelations[mockNodeID]
assert.True(t, exists, "Node should still be in backupToNodeRelations")
time.Sleep(200 * time.Millisecond)
}
@@ -549,6 +656,14 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
node3ID := uuid.New()
now := time.Now().UTC()
defer func() {
// Clean up all mock nodes
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node1ID})
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node2ID})
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node3ID})
cache_utils.ClearAllCache()
}()
err := CreateMockNodeInRegistry(node1ID, 100, now)
assert.NoError(t, err)
err = CreateMockNodeInRegistry(node2ID, 100, now)
@@ -557,17 +672,17 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
assert.NoError(t, err)
for range 5 {
err = nodesRegistry.IncrementBackupsInProgress(node1ID.String())
err = nodesRegistry.IncrementTasksInProgress(node1ID.String())
assert.NoError(t, err)
}
for range 2 {
err = nodesRegistry.IncrementBackupsInProgress(node2ID.String())
err = nodesRegistry.IncrementTasksInProgress(node2ID.String())
assert.NoError(t, err)
}
for range 8 {
err = nodesRegistry.IncrementBackupsInProgress(node3ID.String())
err = nodesRegistry.IncrementTasksInProgress(node3ID.String())
assert.NoError(t, err)
}
@@ -584,17 +699,24 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
node50MBsID := uuid.New()
now := time.Now().UTC()
defer func() {
// Clean up all mock nodes
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node100MBsID})
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node50MBsID})
cache_utils.ClearAllCache()
}()
err := CreateMockNodeInRegistry(node100MBsID, 100, now)
assert.NoError(t, err)
err = CreateMockNodeInRegistry(node50MBsID, 50, now)
assert.NoError(t, err)
for range 10 {
err = nodesRegistry.IncrementBackupsInProgress(node100MBsID.String())
err = nodesRegistry.IncrementTasksInProgress(node100MBsID.String())
assert.NoError(t, err)
}
err = nodesRegistry.IncrementBackupsInProgress(node50MBsID.String())
err = nodesRegistry.IncrementTasksInProgress(node50MBsID.String())
assert.NoError(t, err)
leastBusyNodeID, err := GetBackupsScheduler().calculateLeastBusyNode()
@@ -622,9 +744,11 @@ func Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStat
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
cache_utils.ClearAllCache()
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
@@ -707,3 +831,204 @@ func Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStat
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
schedulerCancel := StartSchedulerForTest(t)
defer schedulerCancel()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Get initial active task count
stats, err := nodesRegistry.GetNodesStats()
assert.NoError(t, err)
var initialActiveTasks int
for _, stat := range stats {
if stat.ID == backuperNode.nodeID {
initialActiveTasks = stat.ActiveTasks
break
}
}
t.Logf("Initial active tasks: %d", initialActiveTasks)
// Start backup
GetBackupsScheduler().StartBackup(database.ID, false)
// Wait for backup to complete
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
// Verify backup was created and completed
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusCompleted, backups[0].Status)
// Wait for active task count to decrease
decreased := WaitForActiveTasksDecrease(
t,
backuperNode.nodeID,
initialActiveTasks+1,
10*time.Second,
)
assert.True(t, decreased, "Active task count should have decreased after backup completion")
// Verify final active task count equals initial count
finalStats, err := nodesRegistry.GetNodesStats()
assert.NoError(t, err)
for _, stat := range finalStats {
if stat.ID == backuperNode.nodeID {
t.Logf("Final active tasks: %d", stat.ActiveTasks)
assert.Equal(t, initialActiveTasks, stat.ActiveTasks,
"Active task count should return to initial value after backup completion")
break
}
}
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
schedulerCancel := StartSchedulerForTest(t)
defer schedulerCancel()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Set a wrong password so the backup fails. Save via the repository directly
// to bypass the service-layer validation, which would fail on the connection test.
database.Postgresql.Password = "intentionally_wrong_password"
dbRepo := &databases.DatabaseRepository{}
_, err := dbRepo.Save(database)
assert.NoError(t, err)
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Get initial active task count
stats, err := nodesRegistry.GetNodesStats()
assert.NoError(t, err)
var initialActiveTasks int
for _, stat := range stats {
if stat.ID == backuperNode.nodeID {
initialActiveTasks = stat.ActiveTasks
break
}
}
t.Logf("Initial active tasks: %d", initialActiveTasks)
// Start backup
GetBackupsScheduler().StartBackup(database.ID, false)
// Wait for backup to fail
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
// Verify backup was created and failed
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusFailed, backups[0].Status)
assert.NotNil(t, backups[0].FailMessage)
if backups[0].FailMessage != nil {
t.Logf("Backup failed with message: %s", *backups[0].FailMessage)
}
// Wait for active task count to decrease
decreased := WaitForActiveTasksDecrease(
t,
backuperNode.nodeID,
initialActiveTasks+1,
10*time.Second,
)
assert.True(t, decreased, "Active task count should have decreased after backup failure")
// Verify final active task count equals initial count
finalStats, err := nodesRegistry.GetNodesStats()
assert.NoError(t, err)
for _, stat := range finalStats {
if stat.ID == backuperNode.nodeID {
t.Logf("Final active tasks: %d", stat.ActiveTasks)
assert.Equal(t, initialActiveTasks, stat.ActiveTasks,
"Active task count should return to initial value after backup failure")
break
}
}
time.Sleep(200 * time.Millisecond)
}

View File

@@ -12,6 +12,7 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
task_registry "databasus-backend/internal/features/tasks/registry"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_services "databasus-backend/internal/features/workspaces/services"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
@@ -42,7 +43,7 @@ func CreateTestBackuperNode() *BackuperNode {
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
backupCancelManager,
taskCancelManager,
nodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
@@ -138,6 +139,34 @@ func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.
return nil
}
// StartSchedulerForTest starts the BackupsScheduler in a goroutine for testing.
// The scheduler subscribes to task completions and manages the backup lifecycle.
// Returns a context cancel function that should be deferred to stop the scheduler.
func StartSchedulerForTest(t *testing.T) context.CancelFunc {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
GetBackupsScheduler().Run(ctx)
close(done)
}()
// Give scheduler time to subscribe to completions
time.Sleep(100 * time.Millisecond)
t.Log("BackupsScheduler started")
return func() {
cancel()
select {
case <-done:
t.Log("BackupsScheduler stopped gracefully")
case <-time.After(2 * time.Second):
t.Log("BackupsScheduler stop timeout")
}
}
}
// StopBackuperNodeForTest stops the BackuperNode by canceling its context.
// It waits for the node to unregister from the registry.
func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNode *BackuperNode) {
@@ -167,7 +196,7 @@ func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNo
}
func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
backupNode := BackupNode{
backupNode := task_registry.TaskNode{
ID: nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: lastHeartbeat,
@@ -181,7 +210,7 @@ func UpdateNodeHeartbeatDirectly(
throughputMBs int,
lastHeartbeat time.Time,
) error {
backupNode := BackupNode{
backupNode := task_registry.TaskNode{
ID: nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: lastHeartbeat,
@@ -190,7 +219,7 @@ func UpdateNodeHeartbeatDirectly(
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
}
func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
func GetNodeFromRegistry(nodeID uuid.UUID) (*task_registry.TaskNode, error) {
nodes, err := nodesRegistry.GetAvailableNodes()
if err != nil {
return nil, err
@@ -204,3 +233,48 @@ func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
return nil, fmt.Errorf("node not found")
}
// WaitForActiveTasksDecrease waits for the active task count to decrease below the initial count.
// It polls the registry every 500ms until the count decreases or the timeout is reached.
// Returns true if the count decreased, false if the timeout was reached.
func WaitForActiveTasksDecrease(
t *testing.T,
nodeID uuid.UUID,
initialCount int,
timeout time.Duration,
) bool {
deadline := time.Now().UTC().Add(timeout)
for time.Now().UTC().Before(deadline) {
stats, err := nodesRegistry.GetNodesStats()
if err != nil {
t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err)
time.Sleep(500 * time.Millisecond)
continue
}
for _, stat := range stats {
if stat.ID == nodeID {
t.Logf(
"WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)",
stat.ActiveTasks,
initialCount,
)
if stat.ActiveTasks < initialCount {
t.Logf(
"WaitForActiveTasksDecrease: active tasks decreased from %d to %d",
initialCount,
stat.ActiveTasks,
)
return true
}
break
}
}
time.Sleep(500 * time.Millisecond)
}
t.Logf("WaitForActiveTasksDecrease: timeout waiting for active tasks to decrease")
return false
}

View File

@@ -1,75 +0,0 @@
package backups_cancellation
import (
"context"
cache_utils "databasus-backend/internal/util/cache"
"log/slog"
"sync"
"github.com/google/uuid"
)
const backupCancelChannel = "backup:cancel"
type BackupCancelManager struct {
mu sync.RWMutex
cancelFuncs map[uuid.UUID]context.CancelFunc
pubsub *cache_utils.PubSubManager
logger *slog.Logger
}
func (m *BackupCancelManager) StartSubscription() {
ctx := context.Background()
handler := func(message string) {
backupID, err := uuid.Parse(message)
if err != nil {
m.logger.Error("Invalid backup ID in cancel message", "message", message, "error", err)
return
}
m.mu.Lock()
defer m.mu.Unlock()
cancelFunc, exists := m.cancelFuncs[backupID]
if exists {
cancelFunc()
delete(m.cancelFuncs, backupID)
m.logger.Info("Cancelled backup via Pub/Sub", "backupID", backupID)
}
}
err := m.pubsub.Subscribe(ctx, backupCancelChannel, handler)
if err != nil {
m.logger.Error("Failed to subscribe to backup cancel channel", "error", err)
} else {
m.logger.Info("Successfully subscribed to backup cancel channel")
}
}
func (m *BackupCancelManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
m.mu.Lock()
defer m.mu.Unlock()
m.cancelFuncs[backupID] = cancelFunc
m.logger.Debug("Registered backup", "backupID", backupID)
}
func (m *BackupCancelManager) CancelBackup(backupID uuid.UUID) error {
ctx := context.Background()
err := m.pubsub.Publish(ctx, backupCancelChannel, backupID.String())
if err != nil {
m.logger.Error("Failed to publish cancel message", "backupID", backupID, "error", err)
return err
}
m.logger.Info("Published backup cancel message", "backupID", backupID)
return nil
}
func (m *BackupCancelManager) UnregisterBackup(backupID uuid.UUID) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.cancelFuncs, backupID)
m.logger.Debug("Unregistered backup", "backupID", backupID)
}

View File

@@ -913,7 +913,7 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
assert.NoError(t, err)
// Register a cancellable context for the backup
GetBackupService().backupCancelManager.RegisterBackup(backup.ID, func() {})
GetBackupService().taskCancelManager.RegisterTask(backup.ID, func() {})
resp := test_utils.MakePostRequest(
t,

View File

@@ -3,7 +3,6 @@ package backups
import (
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/backuping"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/backups/backups/usecases"
@@ -12,6 +11,7 @@ import (
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
@@ -19,7 +19,7 @@ import (
var backupRepository = &backups_core.BackupRepository{}
var backupCancelManager = backups_cancellation.GetBackupCancelManager()
var taskCancelManager = task_cancellation.GetTaskCancelManager()
var backupService = &BackupService{
databaseService: databases.GetDatabaseService(),
@@ -35,7 +35,7 @@ var backupService = &BackupService{
backupRemoveListeners: []backups_core.BackupRemoveListener{},
workspaceService: workspaces_services.GetWorkspaceService(),
auditLogService: audit_logs.GetAuditLogService(),
backupCancelManager: backupCancelManager,
taskCancelManager: taskCancelManager,
downloadTokenService: backups_download.GetDownloadTokenService(),
backupSchedulerService: backuping.GetBackupsScheduler(),
}

View File

@@ -9,7 +9,6 @@ import (
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/backuping"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/backups/backups/encryption"
@@ -18,6 +17,7 @@ import (
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
users_models "databasus-backend/internal/features/users/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
@@ -43,7 +43,7 @@ type BackupService struct {
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
backupCancelManager *backups_cancellation.BackupCancelManager
taskCancelManager *task_cancellation.TaskCancelManager
downloadTokenService *backups_download.DownloadTokenService
backupSchedulerService *backuping.BackupsScheduler
}
@@ -226,7 +226,7 @@ func (s *BackupService) CancelBackup(
return errors.New("backup is not in progress")
}
if err := s.backupCancelManager.CancelBackup(backupID); err != nil {
if err := s.taskCancelManager.CancelTask(backupID); err != nil {
return err
}

View File

@@ -85,6 +85,27 @@ func (p *PostgresqlDatabase) Validate() error {
return errors.New("cpu count must be greater than 0")
}
// Prevent Databasus from backing up itself
// Databasus runs an internal PostgreSQL instance that should not be backed up through the UI
// because it would expose internal metadata to non-system administrators.
// To properly back up Databasus, see: https://databasus.com/faq#backup-databasus
if p.Database != nil && *p.Database != "" {
localhostHosts := []string{"localhost", "127.0.0.1", "172.17.0.1", "host.docker.internal"}
isLocalhost := false
for _, host := range localhostHosts {
if strings.EqualFold(p.Host, host) {
isLocalhost = true
break
}
}
if isLocalhost && strings.EqualFold(*p.Database, "databasus") {
return errors.New(
"backing up Databasus internal database is not allowed. To backup Databasus itself, see https://databasus.com/faq#backup-databasus",
)
}
}
return nil
}

View File

@@ -705,6 +705,233 @@ func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
assert.Contains(t, err.Error(), "permission denied")
}
func Test_Validate_WhenLocalhostAndDatabasus_ReturnsError(t *testing.T) {
testCases := []struct {
name string
host string
username string
database string
}{
{
name: "localhost with databasus db",
host: "localhost",
username: "postgres",
database: "databasus",
},
{
name: "127.0.0.1 with databasus db",
host: "127.0.0.1",
username: "postgres",
database: "databasus",
},
{
name: "172.17.0.1 with databasus db",
host: "172.17.0.1",
username: "postgres",
database: "databasus",
},
{
name: "host.docker.internal with databasus db",
host: "host.docker.internal",
username: "postgres",
database: "databasus",
},
{
name: "LOCALHOST (uppercase) with DATABASUS db",
host: "LOCALHOST",
username: "POSTGRES",
database: "DATABASUS",
},
{
name: "LocalHost (mixed case) with DataBasus db",
host: "LocalHost",
username: "anyuser",
database: "DataBasus",
},
{
name: "localhost with databasus and any username",
host: "localhost",
username: "myuser",
database: "databasus",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
pgModel := &PostgresqlDatabase{
Host: tc.host,
Port: 5437,
Username: tc.username,
Password: "somepassword",
Database: &tc.database,
CpuCount: 1,
}
err := pgModel.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), "backing up Databasus internal database is not allowed")
assert.Contains(t, err.Error(), "https://databasus.com/faq#backup-databasus")
})
}
}
func Test_Validate_WhenNotLocalhostOrNotDatabasus_ValidatesSuccessfully(t *testing.T) {
testCases := []struct {
name string
host string
username string
database string
}{
{
name: "different host (remote server) with databasus db",
host: "192.168.1.100",
username: "postgres",
database: "databasus",
},
{
name: "different database name on localhost",
host: "localhost",
username: "postgres",
database: "myapp",
},
{
name: "all different",
host: "db.example.com",
username: "appuser",
database: "production",
},
{
name: "localhost with postgres database",
host: "localhost",
username: "postgres",
database: "postgres",
},
{
name: "remote host with databasus db name (allowed)",
host: "db.example.com",
username: "postgres",
database: "databasus",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
pgModel := &PostgresqlDatabase{
Host: tc.host,
Port: 5432,
Username: tc.username,
Password: "somepassword",
Database: &tc.database,
CpuCount: 1,
}
err := pgModel.Validate()
assert.NoError(t, err)
})
}
}
func Test_Validate_WhenDatabaseIsNil_ValidatesSuccessfully(t *testing.T) {
pgModel := &PostgresqlDatabase{
Host: "localhost",
Port: 5437,
Username: "postgres",
Password: "somepassword",
Database: nil,
CpuCount: 1,
}
err := pgModel.Validate()
assert.NoError(t, err)
}
func Test_Validate_WhenDatabaseIsEmpty_ValidatesSuccessfully(t *testing.T) {
emptyDb := ""
pgModel := &PostgresqlDatabase{
Host: "localhost",
Port: 5437,
Username: "postgres",
Password: "somepassword",
Database: &emptyDb,
CpuCount: 1,
}
err := pgModel.Validate()
assert.NoError(t, err)
}
func Test_Validate_WhenRequiredFieldsMissing_ReturnsError(t *testing.T) {
testCases := []struct {
name string
model *PostgresqlDatabase
expectedError string
}{
{
name: "missing host",
model: &PostgresqlDatabase{
Host: "",
Port: 5432,
Username: "user",
Password: "pass",
CpuCount: 1,
},
expectedError: "host is required",
},
{
name: "missing port",
model: &PostgresqlDatabase{
Host: "localhost",
Port: 0,
Username: "user",
Password: "pass",
CpuCount: 1,
},
expectedError: "port is required",
},
{
name: "missing username",
model: &PostgresqlDatabase{
Host: "localhost",
Port: 5432,
Username: "",
Password: "pass",
CpuCount: 1,
},
expectedError: "username is required",
},
{
name: "missing password",
model: &PostgresqlDatabase{
Host: "localhost",
Port: 5432,
Username: "user",
Password: "",
CpuCount: 1,
},
expectedError: "password is required",
},
{
name: "invalid cpu count",
model: &PostgresqlDatabase{
Host: "localhost",
Port: 5432,
Username: "user",
Password: "pass",
CpuCount: 0,
},
expectedError: "cpu count must be greater than 0",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.model.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), tc.expectedError)
})
}
}
type PostgresContainer struct {
Host string
Port int

View File

@@ -0,0 +1,75 @@
package task_cancellation
import (
"context"
cache_utils "databasus-backend/internal/util/cache"
"log/slog"
"sync"
"github.com/google/uuid"
)
const taskCancelChannel = "task:cancel"
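// TaskCancelManager coordinates cancellation of running tasks across instances.
// Each instance registers local context cancel functions by task ID; CancelTask
// publishes the ID on a shared Pub/Sub channel so whichever instance actually
// runs the task cancels it locally.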
type TaskCancelManager struct {
mu sync.RWMutex
cancelFuncs map[uuid.UUID]context.CancelFunc
pubsub *cache_utils.PubSubManager
logger *slog.Logger
}
func (m *TaskCancelManager) StartSubscription() {
ctx := context.Background()
handler := func(message string) {
taskID, err := uuid.Parse(message)
if err != nil {
m.logger.Error("Invalid task ID in cancel message", "message", message, "error", err)
return
}
m.mu.Lock()
defer m.mu.Unlock()
cancelFunc, exists := m.cancelFuncs[taskID]
if exists {
cancelFunc()
delete(m.cancelFuncs, taskID)
m.logger.Info("Cancelled task via Pub/Sub", "taskID", taskID)
}
}
err := m.pubsub.Subscribe(ctx, taskCancelChannel, handler)
if err != nil {
m.logger.Error("Failed to subscribe to task cancel channel", "error", err)
} else {
m.logger.Info("Successfully subscribed to task cancel channel")
}
}
func (m *TaskCancelManager) RegisterTask(task uuid.UUID, cancelFunc context.CancelFunc) {
m.mu.Lock()
defer m.mu.Unlock()
m.cancelFuncs[task] = cancelFunc
m.logger.Debug("Registered task", "taskID", task)
}
func (m *TaskCancelManager) CancelTask(taskID uuid.UUID) error {
ctx := context.Background()
err := m.pubsub.Publish(ctx, taskCancelChannel, taskID.String())
if err != nil {
m.logger.Error("Failed to publish cancel message", "taskID", taskID, "error", err)
return err
}
m.logger.Info("Published task cancel message", "taskID", taskID)
return nil
}
func (m *TaskCancelManager) UnregisterTask(taskID uuid.UUID) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.cancelFuncs, taskID)
m.logger.Debug("Unregistered task", "taskID", taskID)
}

View File

@@ -1,4 +1,4 @@
package backups_cancellation
package task_cancellation
import (
"context"
@@ -10,41 +10,41 @@ import (
"github.com/stretchr/testify/assert"
)
func Test_RegisterBackup_BackupRegisteredSuccessfully(t *testing.T) {
manager := backupCancelManager
func Test_RegisterTask_TaskRegisteredSuccessfully(t *testing.T) {
manager := taskCancelManager
backupID := uuid.New()
taskID := uuid.New()
_, cancel := context.WithCancel(context.Background())
defer cancel()
manager.RegisterBackup(backupID, cancel)
manager.RegisterTask(taskID, cancel)
manager.mu.RLock()
_, exists := manager.cancelFuncs[backupID]
_, exists := manager.cancelFuncs[taskID]
manager.mu.RUnlock()
assert.True(t, exists, "Backup should be registered")
assert.True(t, exists, "Task should be registered")
}
func Test_UnregisterBackup_BackupUnregisteredSuccessfully(t *testing.T) {
manager := backupCancelManager
func Test_UnregisterTask_TaskUnregisteredSuccessfully(t *testing.T) {
manager := taskCancelManager
backupID := uuid.New()
taskID := uuid.New()
_, cancel := context.WithCancel(context.Background())
defer cancel()
manager.RegisterBackup(backupID, cancel)
manager.UnregisterBackup(backupID)
manager.RegisterTask(taskID, cancel)
manager.UnregisterTask(taskID)
manager.mu.RLock()
_, exists := manager.cancelFuncs[backupID]
_, exists := manager.cancelFuncs[taskID]
manager.mu.RUnlock()
assert.False(t, exists, "Backup should be unregistered")
assert.False(t, exists, "Task should be unregistered")
}
func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
manager := backupCancelManager
func Test_CancelTask_OnSameInstance_TaskCancelledViaPubSub(t *testing.T) {
manager := taskCancelManager
backupID := uuid.New()
taskID := uuid.New()
ctx, cancel := context.WithCancel(context.Background())
cancelled := false
@@ -57,11 +57,11 @@ func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
cancel()
}
manager.RegisterBackup(backupID, wrappedCancel)
manager.RegisterTask(taskID, wrappedCancel)
manager.StartSubscription()
time.Sleep(100 * time.Millisecond)
err := manager.CancelBackup(backupID)
err := manager.CancelTask(taskID)
assert.NoError(t, err, "Cancel should not return error")
time.Sleep(500 * time.Millisecond)
@@ -74,11 +74,11 @@ func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
assert.Error(t, ctx.Err(), "Context should be cancelled")
}
func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t *testing.T) {
manager1 := backupCancelManager
manager2 := backupCancelManager
func Test_CancelTask_FromDifferentInstance_TaskCancelledOnRunningInstance(t *testing.T) {
manager1 := taskCancelManager
manager2 := taskCancelManager
backupID := uuid.New()
taskID := uuid.New()
ctx, cancel := context.WithCancel(context.Background())
cancelled := false
@@ -91,13 +91,13 @@ func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t
cancel()
}
manager1.RegisterBackup(backupID, wrappedCancel)
manager1.RegisterTask(taskID, wrappedCancel)
manager1.StartSubscription()
manager2.StartSubscription()
time.Sleep(100 * time.Millisecond)
err := manager2.CancelBackup(backupID)
err := manager2.CancelTask(taskID)
assert.NoError(t, err, "Cancel should not return error")
time.Sleep(500 * time.Millisecond)
@@ -110,29 +110,29 @@ func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t
assert.Error(t, ctx.Err(), "Context should be cancelled")
}
func Test_CancelBackup_WhenBackupDoesNotExist_NoErrorReturned(t *testing.T) {
manager := backupCancelManager
func Test_CancelTask_WhenTaskDoesNotExist_NoErrorReturned(t *testing.T) {
manager := taskCancelManager
manager.StartSubscription()
time.Sleep(100 * time.Millisecond)
nonExistentID := uuid.New()
err := manager.CancelBackup(nonExistentID)
assert.NoError(t, err, "Cancelling non-existent backup should not error")
err := manager.CancelTask(nonExistentID)
assert.NoError(t, err, "Cancelling non-existent task should not error")
}
func Test_CancelBackup_WithMultipleBackups_AllBackupsCancelled(t *testing.T) {
manager := backupCancelManager
func Test_CancelTask_WithMultipleTasks_AllTasksCancelled(t *testing.T) {
manager := taskCancelManager
numBackups := 5
backupIDs := make([]uuid.UUID, numBackups)
contexts := make([]context.Context, numBackups)
cancels := make([]context.CancelFunc, numBackups)
cancelledFlags := make([]bool, numBackups)
numTasks := 5
taskIDs := make([]uuid.UUID, numTasks)
contexts := make([]context.Context, numTasks)
cancels := make([]context.CancelFunc, numTasks)
cancelledFlags := make([]bool, numTasks)
var mu sync.Mutex
for i := 0; i < numBackups; i++ {
backupIDs[i] = uuid.New()
for i := 0; i < numTasks; i++ {
taskIDs[i] = uuid.New()
contexts[i], cancels[i] = context.WithCancel(context.Background())
idx := i
@@ -143,31 +143,31 @@ func Test_CancelBackup_WithMultipleBackups_AllBackupsCancelled(t *testing.T) {
cancels[idx]()
}
manager.RegisterBackup(backupIDs[i], wrappedCancel)
manager.RegisterTask(taskIDs[i], wrappedCancel)
}
manager.StartSubscription()
time.Sleep(100 * time.Millisecond)
for i := 0; i < numBackups; i++ {
err := manager.CancelBackup(backupIDs[i])
for i := 0; i < numTasks; i++ {
err := manager.CancelTask(taskIDs[i])
assert.NoError(t, err, "Cancel should not return error")
}
time.Sleep(1 * time.Second)
mu.Lock()
for i := 0; i < numBackups; i++ {
assert.True(t, cancelledFlags[i], "Backup %d should be cancelled", i)
for i := 0; i < numTasks; i++ {
assert.True(t, cancelledFlags[i], "Task %d should be cancelled", i)
assert.Error(t, contexts[i].Err(), "Context %d should be cancelled", i)
}
mu.Unlock()
}
func Test_CancelBackup_AfterUnregister_BackupNotCancelled(t *testing.T) {
manager := backupCancelManager
func Test_CancelTask_AfterUnregister_TaskNotCancelled(t *testing.T) {
manager := taskCancelManager
backupID := uuid.New()
taskID := uuid.New()
_, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -181,13 +181,13 @@ func Test_CancelBackup_AfterUnregister_BackupNotCancelled(t *testing.T) {
cancel()
}
manager.RegisterBackup(backupID, wrappedCancel)
manager.RegisterTask(taskID, wrappedCancel)
manager.StartSubscription()
time.Sleep(100 * time.Millisecond)
manager.UnregisterBackup(backupID)
manager.UnregisterTask(taskID)
err := manager.CancelBackup(backupID)
err := manager.CancelTask(taskID)
assert.NoError(t, err, "Cancel should not return error")
time.Sleep(500 * time.Millisecond)

View File

@@ -1,4 +1,4 @@
package backups_cancellation
package task_cancellation
import (
"context"
@@ -9,17 +9,17 @@ import (
"github.com/google/uuid"
)
var backupCancelManager = &BackupCancelManager{
var taskCancelManager = &TaskCancelManager{
sync.RWMutex{},
make(map[uuid.UUID]context.CancelFunc),
cache_utils.NewPubSubManager(),
logger.GetLogger(),
}
func GetBackupCancelManager() *BackupCancelManager {
return backupCancelManager
func GetTaskCancelManager() *TaskCancelManager {
return taskCancelManager
}
func SetupDependencies() {
backupCancelManager.StartSubscription()
taskCancelManager.StartSubscription()
}

View File

@@ -0,0 +1,18 @@
package task_registry
import (
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
)
var taskNodesRegistry = &TaskNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
}
func GetTaskNodesRegistry() *TaskNodesRegistry {
return taskNodesRegistry
}

View File

@@ -0,0 +1,29 @@
package task_registry
import (
"time"
"github.com/google/uuid"
)
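// TaskNode describes a worker node registered in the task registry,
// including its declared throughput and last heartbeat time.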
type TaskNode struct {
ID uuid.UUID `json:"id"`
ThroughputMBs int `json:"throughputMBs"`
LastHeartbeat time.Time `json:"lastHeartbeat"`
}
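// TaskNodeStats reports the current number of active tasks on a node.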
type TaskNodeStats struct {
ID uuid.UUID `json:"id"`
ActiveTasks int `json:"activeTasks"`
}
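// TaskSubmitMessage is the Pub/Sub payload used to assign a task to a specific node.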
type TaskSubmitMessage struct {
NodeID string `json:"nodeId"`
TaskID string `json:"taskId"`
IsCallNotifier bool `json:"isCallNotifier"`
}
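// TaskCompletionMessage is the Pub/Sub payload a node publishes when it finishes a task.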
type TaskCompletionMessage struct {
NodeID string `json:"nodeId"`
TaskID string `json:"taskId"`
}

View File

@@ -0,0 +1,641 @@
package task_registry
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"time"
cache_utils "databasus-backend/internal/util/cache"
"github.com/google/uuid"
"github.com/valkey-io/valkey-go"
)
const (
nodeInfoKeyPrefix = "node:"
nodeInfoKeySuffix = ":info"
nodeActiveTasksPrefix = "node:"
nodeActiveTasksSuffix = ":active_tasks"
taskSubmitChannel = "task:submit"
taskCompletionChannel = "task:completion"
deadNodeThreshold = 2 * time.Minute
cleanupTickerInterval = 1 * time.Second
)
// TaskNodesRegistry keeps the tasks scheduler (for backups or restores)
// in sync with the task nodes that perform network-intensive task processing
//
// Features:
// - Track node availability and load level
// - Assign tasks from the scheduler to the node that should process them
// - Notify the scheduler from a node about task completion
//
// Important things to remember:
// - A node can hold different task types, so when a task is assigned
//   or a node's tasks are cleaned up, a DB check must confirm that a task
//   with this ID exists for this task type at all
// - Nodes without a heartbeat for more than 2 minutes are not included
//   in the available nodes list and stats
//
// Cleanup of dead nodes is performed on 2 levels:
// - The list and stats functions do not return dead nodes
// - Dead nodes are periodically cleaned up in the cache (so the cache
//   does not accumulate too many dead nodes)
//
// An illustrative usage sketch follows the struct definition below.
type TaskNodesRegistry struct {
client valkey.Client
logger *slog.Logger
timeout time.Duration
pubsubTasks *cache_utils.PubSubManager
pubsubCompletions *cache_utils.PubSubManager
}
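// ---------------------------------------------------------------------------
// Editor's note: illustrative usage sketch, not part of this change set.
// It shows how a worker node and the scheduler could interact with the
// registry, assuming only the exported API in this package; the IDs are
// placeholders and the actual task processing is elided.
func exampleRegistryUsage() error {
	registry := GetTaskNodesRegistry()
	node := TaskNode{ID: uuid.New(), ThroughputMBs: 100}
	nodeID := node.ID.String()
	// Node side: heartbeat so the scheduler considers this node alive,
	// then listen for tasks assigned to it.
	if err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node); err != nil {
		return err
	}
	err := registry.SubscribeNodeForTasksAssignment(nodeID, func(taskID uuid.UUID, isCallNotifier bool) {
		_ = registry.IncrementTasksInProgress(nodeID)
		// ... perform the backup/restore work for taskID here ...
		_ = registry.DecrementTasksInProgress(nodeID)
		_ = registry.PublishTaskCompletion(nodeID, taskID)
	})
	if err != nil {
		return err
	}
	// Scheduler side: pick any live node and assign a task to it.
	nodes, err := registry.GetAvailableNodes()
	if err != nil || len(nodes) == 0 {
		return err
	}
	return registry.AssignTaskToNode(nodes[0].ID.String(), uuid.New(), false)
}
// ---------------------------------------------------------------------------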
func (r *TaskNodesRegistry) Run(ctx context.Context) {
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
}
ticker := time.NewTicker(cleanupTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes", "error", err)
}
}
}
}
func (r *TaskNodesRegistry) GetAvailableNodes() ([]TaskNode, error) {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
var allKeys []string
cursor := uint64(0)
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
for {
result := r.client.Do(
ctx,
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
)
if result.Error() != nil {
return nil, fmt.Errorf("failed to scan node keys: %w", result.Error())
}
scanResult, err := result.AsScanEntry()
if err != nil {
return nil, fmt.Errorf("failed to parse scan result: %w", err)
}
allKeys = append(allKeys, scanResult.Elements...)
cursor = scanResult.Cursor
if cursor == 0 {
break
}
}
if len(allKeys) == 0 {
return []TaskNode{}, nil
}
keyDataMap, err := r.pipelineGetKeys(allKeys)
if err != nil {
return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
}
threshold := time.Now().UTC().Add(-deadNodeThreshold)
var nodes []TaskNode
for key, data := range keyDataMap {
// Skip if the key doesn't exist (data is empty)
if len(data) == 0 {
continue
}
var node TaskNode
if err := json.Unmarshal(data, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
continue
}
// Skip nodes with zero/uninitialized heartbeat
if node.LastHeartbeat.IsZero() {
continue
}
if node.LastHeartbeat.Before(threshold) {
continue
}
nodes = append(nodes, node)
}
return nodes, nil
}
func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
var allKeys []string
cursor := uint64(0)
pattern := nodeActiveTasksPrefix + "*" + nodeActiveTasksSuffix
for {
result := r.client.Do(
ctx,
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build(),
)
if result.Error() != nil {
return nil, fmt.Errorf("failed to scan active tasks keys: %w", result.Error())
}
scanResult, err := result.AsScanEntry()
if err != nil {
return nil, fmt.Errorf("failed to parse scan result: %w", err)
}
allKeys = append(allKeys, scanResult.Elements...)
cursor = scanResult.Cursor
if cursor == 0 {
break
}
}
if len(allKeys) == 0 {
return []TaskNodeStats{}, nil
}
keyDataMap, err := r.pipelineGetKeys(allKeys)
if err != nil {
return nil, fmt.Errorf("failed to pipeline get active tasks keys: %w", err)
}
var nodeInfoKeys []string
nodeIDToStatsKey := make(map[string]string)
for key := range keyDataMap {
nodeID := r.extractNodeIDFromKey(key, nodeActiveTasksPrefix, nodeActiveTasksSuffix)
nodeIDStr := nodeID.String()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeIDStr, nodeInfoKeySuffix)
nodeInfoKeys = append(nodeInfoKeys, infoKey)
nodeIDToStatsKey[infoKey] = key
}
nodeInfoMap, err := r.pipelineGetKeys(nodeInfoKeys)
if err != nil {
return nil, fmt.Errorf("failed to pipeline get node info keys: %w", err)
}
threshold := time.Now().UTC().Add(-deadNodeThreshold)
var stats []TaskNodeStats
for infoKey, nodeData := range nodeInfoMap {
// Skip if the info key doesn't exist (nodeData is empty)
if len(nodeData) == 0 {
continue
}
var node TaskNode
if err := json.Unmarshal(nodeData, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data", "key", infoKey, "error", err)
continue
}
// Skip nodes with zero/uninitialized heartbeat
if node.LastHeartbeat.IsZero() {
continue
}
if node.LastHeartbeat.Before(threshold) {
continue
}
statsKey := nodeIDToStatsKey[infoKey]
tasksData := keyDataMap[statsKey]
count, err := r.parseIntFromBytes(tasksData)
if err != nil {
r.logger.Warn("Failed to parse active tasks count", "key", statsKey, "error", err)
continue
}
stat := TaskNodeStats{
ID: node.ID,
ActiveTasks: int(count),
}
stats = append(stats, stat)
}
return stats, nil
}
func (r *TaskNodesRegistry) IncrementTasksInProgress(nodeID string) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix)
result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())
if result.Error() != nil {
return fmt.Errorf(
"failed to increment tasks in progress for node %s: %w",
nodeID,
result.Error(),
)
}
return nil
}
func (r *TaskNodesRegistry) DecrementTasksInProgress(nodeID string) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix)
result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
if result.Error() != nil {
return fmt.Errorf(
"failed to decrement tasks in progress for node %s: %w",
nodeID,
result.Error(),
)
}
newValue, err := result.AsInt64()
if err != nil {
return fmt.Errorf("failed to parse decremented value for node %s: %w", nodeID, err)
}
if newValue < 0 {
setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
setCancel()
r.logger.Warn("Active tasks counter went below 0, reset to 0", "nodeID", nodeID)
}
return nil
}
func (r *TaskNodesRegistry) HearthbeatNodeInRegistry(now time.Time, node TaskNode) error {
if now.IsZero() {
return fmt.Errorf("cannot register node with zero heartbeat timestamp")
}
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
node.LastHeartbeat = now
data, err := json.Marshal(node)
if err != nil {
return fmt.Errorf("failed to marshal node: %w", err)
}
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node.ID.String(), nodeInfoKeySuffix)
result := r.client.Do(
ctx,
r.client.B().Set().Key(key).Value(string(data)).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to register node %s: %w", node.ID, result.Error())
}
return nil
}
func (r *TaskNodesRegistry) UnregisterNodeFromRegistry(node TaskNode) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node.ID.String(), nodeInfoKeySuffix)
counterKey := fmt.Sprintf(
"%s%s%s",
nodeActiveTasksPrefix,
node.ID.String(),
nodeActiveTasksSuffix,
)
result := r.client.Do(
ctx,
r.client.B().Del().Key(infoKey, counterKey).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to unregister node %s: %w", node.ID, result.Error())
}
r.logger.Info("Unregistered node from registry", "nodeID", node.ID)
return nil
}
func (r *TaskNodesRegistry) AssignTaskToNode(
targetNodeID string,
taskID uuid.UUID,
isCallNotifier bool,
) error {
ctx := context.Background()
message := TaskSubmitMessage{
NodeID: targetNodeID,
TaskID: taskID.String(),
IsCallNotifier: isCallNotifier,
}
messageJSON, err := json.Marshal(message)
if err != nil {
return fmt.Errorf("failed to marshal task submit message: %w", err)
}
err = r.pubsubTasks.Publish(ctx, taskSubmitChannel, string(messageJSON))
if err != nil {
return fmt.Errorf("failed to publish task submit message: %w", err)
}
return nil
}
func (r *TaskNodesRegistry) SubscribeNodeForTasksAssignment(
nodeID string,
handler func(taskID uuid.UUID, isCallNotifier bool),
) error {
ctx := context.Background()
wrappedHandler := func(message string) {
var msg TaskSubmitMessage
if err := json.Unmarshal([]byte(message), &msg); err != nil {
r.logger.Warn("Failed to unmarshal task submit message", "error", err)
return
}
if msg.NodeID != nodeID {
return
}
taskID, err := uuid.Parse(msg.TaskID)
if err != nil {
r.logger.Warn(
"Failed to parse task ID from message",
"taskId",
msg.TaskID,
"error",
err,
)
return
}
handler(taskID, msg.IsCallNotifier)
}
err := r.pubsubTasks.Subscribe(ctx, taskSubmitChannel, wrappedHandler)
if err != nil {
return fmt.Errorf("failed to subscribe to task submit channel: %w", err)
}
r.logger.Info("Subscribed to task submit channel", "nodeID", nodeID)
return nil
}
func (r *TaskNodesRegistry) UnsubscribeNodeForTasksAssignments() error {
err := r.pubsubTasks.Close()
if err != nil {
return fmt.Errorf("failed to unsubscribe from task submit channel: %w", err)
}
r.logger.Info("Unsubscribed from task submit channel")
return nil
}
func (r *TaskNodesRegistry) PublishTaskCompletion(nodeID string, taskID uuid.UUID) error {
ctx := context.Background()
message := TaskCompletionMessage{
NodeID: nodeID,
TaskID: taskID.String(),
}
messageJSON, err := json.Marshal(message)
if err != nil {
return fmt.Errorf("failed to marshal task completion message: %w", err)
}
err = r.pubsubCompletions.Publish(ctx, taskCompletionChannel, string(messageJSON))
if err != nil {
return fmt.Errorf("failed to publish task completion message: %w", err)
}
return nil
}
func (r *TaskNodesRegistry) SubscribeForTasksCompletions(
handler func(nodeID string, taskID uuid.UUID),
) error {
ctx := context.Background()
wrappedHandler := func(message string) {
var msg TaskCompletionMessage
if err := json.Unmarshal([]byte(message), &msg); err != nil {
r.logger.Warn("Failed to unmarshal task completion message", "error", err)
return
}
taskID, err := uuid.Parse(msg.TaskID)
if err != nil {
r.logger.Warn(
"Failed to parse task ID from completion message",
"taskId",
msg.TaskID,
"error",
err,
)
return
}
handler(msg.NodeID, taskID)
}
err := r.pubsubCompletions.Subscribe(ctx, taskCompletionChannel, wrappedHandler)
if err != nil {
return fmt.Errorf("failed to subscribe to task completion channel: %w", err)
}
r.logger.Info("Subscribed to task completion channel")
return nil
}
func (r *TaskNodesRegistry) UnsubscribeForTasksCompletions() error {
err := r.pubsubCompletions.Close()
if err != nil {
return fmt.Errorf("failed to unsubscribe from task completion channel: %w", err)
}
r.logger.Info("Unsubscribed from task completion channel")
return nil
}
func (r *TaskNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
nodeIDStr := strings.TrimPrefix(key, prefix)
nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix)
nodeID, err := uuid.Parse(nodeIDStr)
if err != nil {
r.logger.Warn("Failed to parse node ID from key", "key", key, "error", err)
return uuid.Nil
}
return nodeID
}
func (r *TaskNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
if len(keys) == 0 {
return make(map[string][]byte), nil
}
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
commands := make([]valkey.Completed, 0, len(keys))
for _, key := range keys {
commands = append(commands, r.client.B().Get().Key(key).Build())
}
results := r.client.DoMulti(ctx, commands...)
keyDataMap := make(map[string][]byte, len(keys))
for i, result := range results {
if result.Error() != nil {
r.logger.Warn("Failed to get key in pipeline", "key", keys[i], "error", result.Error())
continue
}
data, err := result.AsBytes()
if err != nil {
r.logger.Warn("Failed to parse key data in pipeline", "key", keys[i], "error", err)
continue
}
keyDataMap[keys[i]] = data
}
return keyDataMap, nil
}
func (r *TaskNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
str := string(data)
var count int64
_, err := fmt.Sscanf(str, "%d", &count)
if err != nil {
return 0, fmt.Errorf("failed to parse integer from bytes: %w", err)
}
return count, nil
}
func (r *TaskNodesRegistry) cleanupDeadNodes() error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
var allKeys []string
cursor := uint64(0)
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
for {
result := r.client.Do(
ctx,
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to scan node keys: %w", result.Error())
}
scanResult, err := result.AsScanEntry()
if err != nil {
return fmt.Errorf("failed to parse scan result: %w", err)
}
allKeys = append(allKeys, scanResult.Elements...)
cursor = scanResult.Cursor
if cursor == 0 {
break
}
}
if len(allKeys) == 0 {
return nil
}
keyDataMap, err := r.pipelineGetKeys(allKeys)
if err != nil {
return fmt.Errorf("failed to pipeline get node keys: %w", err)
}
threshold := time.Now().UTC().Add(-deadNodeThreshold)
var deadNodeKeys []string
for key, data := range keyDataMap {
// Skip if the key doesn't exist (data is empty)
if len(data) == 0 {
continue
}
var node TaskNode
if err := json.Unmarshal(data, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data during cleanup", "key", key, "error", err)
continue
}
// Skip nodes with zero/uninitialized heartbeat
if node.LastHeartbeat.IsZero() {
continue
}
if node.LastHeartbeat.Before(threshold) {
nodeID := node.ID.String()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeID, nodeInfoKeySuffix)
statsKey := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix)
deadNodeKeys = append(deadNodeKeys, infoKey, statsKey)
r.logger.Info(
"Marking node for cleanup",
"nodeID", nodeID,
"lastHeartbeat", node.LastHeartbeat,
"threshold", threshold,
)
}
}
if len(deadNodeKeys) == 0 {
return nil
}
delCtx, delCancel := context.WithTimeout(context.Background(), r.timeout)
defer delCancel()
result := r.client.Do(
delCtx,
r.client.B().Del().Key(deadNodeKeys...).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to delete dead node keys: %w", result.Error())
}
deletedCount, err := result.AsInt64()
if err != nil {
return fmt.Errorf("failed to parse deleted count: %w", err)
}
r.logger.Info("Cleaned up dead nodes", "deletedKeysCount", deletedCount)
return nil
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,51 @@
package cache_utils
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_ClearAllCache_AfterClear_CacheIsEmpty(t *testing.T) {
client := getCache()
// Arrange: Set up multiple cache entries with different prefixes
testKeys := []struct {
prefix string
key string
value string
}{
{"test:user:", "user1", "John Doe"},
{"test:user:", "user2", "Jane Smith"},
{"test:session:", "session1", "abc123"},
{"test:session:", "session2", "def456"},
{"test:data:", "item1", "value1"},
}
// Set all test keys
for _, tk := range testKeys {
cacheUtil := NewCacheUtil[string](client, tk.prefix)
cacheUtil.Set(tk.key, &tk.value)
}
// Verify keys were set correctly before clearing
for _, tk := range testKeys {
cacheUtil := NewCacheUtil[string](client, tk.prefix)
retrieved := cacheUtil.Get(tk.key)
assert.NotNil(t, retrieved, "Key %s should exist before clearing", tk.prefix+tk.key)
assert.Equal(t, tk.value, *retrieved, "Retrieved value should match set value")
}
// Act: Clear all cache
err := ClearAllCache()
// Assert: No error returned
assert.NoError(t, err, "ClearAllCache should not return an error")
// Assert: All keys should be deleted
for _, tk := range testKeys {
cacheUtil := NewCacheUtil[string](client, tk.prefix)
retrieved := cacheUtil.Get(tk.key)
assert.Nil(t, retrieved, "Key %s should be deleted after clearing", tk.prefix+tk.key)
}
}

View File

@@ -87,17 +87,20 @@ func VerifyMariadbInstallation(
logger *slog.Logger,
envMode env_utils.EnvMode,
mariadbInstallDir string,
isShowLogs bool,
) {
clientVersions := []MariadbClientVersion{MariadbClientLegacy, MariadbClientModern}
for _, clientVersion := range clientVersions {
binDir := getMariadbBasePath(clientVersion, envMode, mariadbInstallDir)
logger.Info(
"Verifying MariaDB installation",
"clientVersion", clientVersion,
"path", binDir,
)
if isShowLogs {
logger.Info(
"Verifying MariaDB installation",
"clientVersion", clientVersion,
"path", binDir,
)
}
if _, err := os.Stat(binDir); os.IsNotExist(err) {
if envMode == env_utils.EnvModeDevelopment {
@@ -133,12 +136,14 @@ func VerifyMariadbInstallation(
}
cmdPath := GetMariadbExecutable(cmd, dummyServerVersion, envMode, mariadbInstallDir)
logger.Info(
"Checking for MariaDB command",
"clientVersion", clientVersion,
"command", cmd,
"path", cmdPath,
)
if isShowLogs {
logger.Info(
"Checking for MariaDB command",
"clientVersion", clientVersion,
"command", cmd,
"path", cmdPath,
)
}
if _, err := os.Stat(cmdPath); os.IsNotExist(err) {
if envMode == env_utils.EnvModeDevelopment {
@@ -162,11 +167,15 @@ func VerifyMariadbInstallation(
continue
}
logger.Info("MariaDB command found", "clientVersion", clientVersion, "command", cmd)
if isShowLogs {
logger.Info("MariaDB command found", "clientVersion", clientVersion, "command", cmd)
}
}
}
logger.Info("MariaDB client tools verification completed!")
if isShowLogs {
logger.Info("MariaDB client tools verification completed!")
}
}
// IsMariadbBackupVersionHigherThanRestoreVersion checks if backup was made with

View File

@@ -53,13 +53,16 @@ func VerifyMongodbInstallation(
logger *slog.Logger,
envMode env_utils.EnvMode,
mongodbInstallDir string,
isShowLogs bool,
) {
binDir := getMongodbBasePath(envMode, mongodbInstallDir)
logger.Info(
"Verifying MongoDB Database Tools installation",
"path", binDir,
)
if isShowLogs {
logger.Info(
"Verifying MongoDB Database Tools installation",
"path", binDir,
)
}
if _, err := os.Stat(binDir); os.IsNotExist(err) {
if envMode == env_utils.EnvModeDevelopment {
@@ -85,11 +88,13 @@ func VerifyMongodbInstallation(
for _, cmd := range requiredCommands {
cmdPath := GetMongodbExecutable(cmd, envMode, mongodbInstallDir)
logger.Info(
"Checking for MongoDB command",
"command", cmd,
"path", cmdPath,
)
if isShowLogs {
logger.Info(
"Checking for MongoDB command",
"command", cmd,
"path", cmdPath,
)
}
if _, err := os.Stat(cmdPath); os.IsNotExist(err) {
if envMode == env_utils.EnvModeDevelopment {
@@ -110,10 +115,14 @@ func VerifyMongodbInstallation(
continue
}
logger.Info("MongoDB command found", "command", cmd)
if isShowLogs {
logger.Info("MongoDB command found", "command", cmd)
}
}
logger.Info("MongoDB Database Tools verification completed!")
if isShowLogs {
logger.Info("MongoDB Database Tools verification completed!")
}
}
// IsMongodbBackupVersionHigherThanRestoreVersion checks if backup was made with

View File

@@ -55,6 +55,7 @@ func VerifyMysqlInstallation(
logger *slog.Logger,
envMode env_utils.EnvMode,
mysqlInstallDir string,
isShowLogs bool,
) {
versions := []MysqlVersion{
MysqlVersion57,
@@ -71,13 +72,15 @@ func VerifyMysqlInstallation(
for _, version := range versions {
binDir := getMysqlBasePath(version, envMode, mysqlInstallDir)
logger.Info(
"Verifying MySQL installation",
"version",
string(version),
"path",
binDir,
)
if isShowLogs {
logger.Info(
"Verifying MySQL installation",
"version",
string(version),
"path",
binDir,
)
}
if _, err := os.Stat(binDir); os.IsNotExist(err) {
if envMode == env_utils.EnvModeDevelopment {
@@ -108,15 +111,17 @@ func VerifyMysqlInstallation(
mysqlInstallDir,
)
logger.Info(
"Checking for MySQL command",
"command",
cmd,
"version",
string(version),
"path",
cmdPath,
)
if isShowLogs {
logger.Info(
"Checking for MySQL command",
"command",
cmd,
"version",
string(version),
"path",
cmdPath,
)
}
if _, err := os.Stat(cmdPath); os.IsNotExist(err) {
if envMode == env_utils.EnvModeDevelopment {
@@ -143,25 +148,31 @@ func VerifyMysqlInstallation(
continue
}
logger.Info(
"MySQL command found",
"command",
cmd,
"version",
string(version),
)
if isShowLogs {
logger.Info(
"MySQL command found",
"command",
cmd,
"version",
string(version),
)
}
}
logger.Info(
"Installation of MySQL verified",
"version",
string(version),
"path",
binDir,
)
if isShowLogs {
logger.Info(
"Installation of MySQL verified",
"version",
string(version),
"path",
binDir,
)
}
}
logger.Info("MySQL version-specific client tools verification completed!")
if isShowLogs {
logger.Info("MySQL version-specific client tools verification completed!")
}
}
// IsMysqlBackupVersionHigherThanRestoreVersion checks if backup was made with

View File

@@ -40,6 +40,7 @@ func VerifyPostgresesInstallation(
logger *slog.Logger,
envMode env_utils.EnvMode,
postgresesInstallDir string,
isShowLogs bool,
) {
versions := []PostgresqlVersion{
PostgresqlVersion12,
@@ -59,13 +60,15 @@ func VerifyPostgresesInstallation(
for _, version := range versions {
binDir := getPostgresqlBasePath(version, envMode, postgresesInstallDir)
logger.Info(
"Verifying PostgreSQL installation",
"version",
string(version),
"path",
binDir,
)
if isShowLogs {
logger.Info(
"Verifying PostgreSQL installation",
"version",
string(version),
"path",
binDir,
)
}
if _, err := os.Stat(binDir); os.IsNotExist(err) {
if envMode == env_utils.EnvModeDevelopment {
@@ -96,15 +99,17 @@ func VerifyPostgresesInstallation(
postgresesInstallDir,
)
logger.Info(
"Checking for PostgreSQL command",
"command",
cmd,
"version",
string(version),
"path",
cmdPath,
)
if isShowLogs {
logger.Info(
"Checking for PostgreSQL command",
"command",
cmd,
"version",
string(version),
"path",
cmdPath,
)
}
if _, err := os.Stat(cmdPath); os.IsNotExist(err) {
if envMode == env_utils.EnvModeDevelopment {
@@ -131,25 +136,33 @@ func VerifyPostgresesInstallation(
os.Exit(1)
}
logger.Info(
"PostgreSQL command found",
"command",
cmd,
"version",
string(version),
)
if isShowLogs {
logger.Info(
"PostgreSQL command found",
"command",
cmd,
"version",
string(version),
)
}
}
logger.Info(
"Installation of PostgreSQL verified",
"version",
string(version),
"path",
binDir,
)
if isShowLogs {
logger.Info(
"Installation of PostgreSQL verified",
"version",
string(version),
"path",
binDir,
)
}
}
logger.Info("All PostgreSQL version-specific client tools verification completed successfully!")
if isShowLogs {
logger.Info(
"All PostgreSQL version-specific client tools verification completed successfully!",
)
}
}
// EscapePgpassField escapes special characters in a field value for .pgpass file format.

View File

@@ -1,29 +0,0 @@
Write ReactComponent with the following structure:
interface Props {
someValue: SomeValue;
}
const someHelperFunction = () => {
...
}
export const ReactComponent = ({ someValue }: Props): JSX.Element => {
// first put states
const [someState, setSomeState] = useState<...>(...)
// then place functions
const loadSomeData = async () => {
...
}
// then hooks
useEffect(() => {
loadSomeData();
});
// then calculated values
const calculatedValue = someValue.calculate();
return <div> ... </div>
}