Compare commits

...

96 Commits

Author SHA1 Message Date
Rostislav Dugin
1f1d80245f Merge pull request #368 from databasus/develop
FIX (restores): Increase restore timeout to 23 hours instead of 1 hour
2026-02-17 14:56:58 +03:00
Rostislav Dugin
16a29cf458 FIX (restores): Increase restore timeout to 23 hours instead of 1 hour 2026-02-17 14:56:25 +03:00
Rostislav Dugin
43e04500ac Merge pull request #367 from databasus/develop
FEATURE (backups): Add meaningful names for backups
2026-02-17 14:50:21 +03:00
Rostislav Dugin
cee3022f85 FEATURE (backups): Add meaningful names for backups 2026-02-17 14:49:33 +03:00
Rostislav Dugin
f46d92c480 Merge pull request #365 from databasus/develop
FIX (audit logs): Get rid of IDs in audit logs and improve naming log…
2026-02-15 01:10:54 +03:00
Rostislav Dugin
10677238d7 FIX (audit logs): Get rid of IDs in audit logs and improve naming logging 2026-02-15 01:06:39 +03:00
Rostislav Dugin
2553203fcf Merge pull request #363 from databasus/develop
FIX (sign up): Return authorization token on sign up to avoid 2-step …
2026-02-15 00:09:00 +03:00
Rostislav Dugin
7b05bd8000 FIX (sign up): Return authorization token on sign up to avoid 2-step sign up 2026-02-15 00:08:01 +03:00
Rostislav Dugin
8d45728f73 Merge pull request #362 from databasus/develop
FEATURE (auth): Add optional CloudFlare Turnstile for sign in \ sign …
2026-02-14 23:19:12 +03:00
Rostislav Dugin
c70ad82c95 FEATURE (auth): Add optional CloudFlare Turnstile for sign in \ sign up \ password reset 2026-02-14 23:11:36 +03:00
Rostislav Dugin
e4bc34d319 Merge pull request #361 from databasus/develop
Develop
2026-02-13 16:57:25 +03:00
Rostislav Dugin
257ae85da7 FIX (postgres): Fix read-only issue when user cannot access tables and partitions created after user creation 2026-02-13 16:56:56 +03:00
Rostislav Dugin
b42c820bb2 FIX (mariadb): Fix events exclusion 2026-02-13 16:21:48 +03:00
Rostislav Dugin
da5c13fb11 Merge pull request #356 from databasus/develop
FIX (mysql & mariadb): Fix creation of backups with exremely large SQ…
2026-02-10 22:40:06 +03:00
Rostislav Dugin
35180360e5 FIX (mysql & mariadb): Fix creation of backups with exremely large SQL statements to avoid OOM 2026-02-10 22:38:18 +03:00
Rostislav Dugin
e4f6cd7a5d Merge pull request #349 from databasus/develop
Develop
2026-02-09 16:42:00 +03:00
Rostislav Dugin
d7b8e6d56a Merge branch 'develop' of https://github.com/databasus/databasus into develop 2026-02-09 16:40:46 +03:00
Rostislav Dugin
6016f23fb2 FEATURE (svr): Add SVR support 2026-02-09 16:39:51 +03:00
Rostislav Dugin
e7c4ee8f6f Merge pull request #345 from databasus/develop
Develop
2026-02-08 23:38:42 +03:00
Rostislav Dugin
a75702a01b Merge pull request #342 from wuast94/patch-1
Add image source label to dockerfiles
2026-02-08 23:38:18 +03:00
Rostislav Dugin
81a21eb907 FEATURE (google drive): Change OAuth authorization flow to local address instead of databasus.com 2026-02-08 23:32:13 +03:00
Marc
33d6bf0147 Add image source label to dockerfiles
To get changelogs shown with Renovate a docker container has to add the source label described in the OCI Image Format Specification.

For reference: https://github.com/renovatebot/renovate/blob/main/lib/modules/datasource/docker/readme.md
2026-02-05 23:30:37 +01:00
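For reference, the label added by this change appears in the Dockerfile diff further down. A minimal sketch of how to confirm it on a built image; the image name is an assumption taken from the Docker Hub badge in the README diff, so substitute whatever tag you build locally:

```bash
# Print the OCI source label that Renovate uses to locate changelogs.
# The image name is assumed (from the README's Docker Hub badge); adjust as needed.
docker inspect \
  --format '{{ index .Config.Labels "org.opencontainers.image.source" }}' \
  databasus/databasus:latest
# Expected: https://github.com/databasus/databasus
```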
Rostislav Dugin
6eb53bb07b Merge pull request #341 from databasus/develop
Develop
2026-02-06 00:25:30 +03:00
Rostislav Dugin
6ac04270b9 FEATURE (healthcheck): Add checking whether backup nodes available for primary node 2026-02-06 00:24:34 +03:00
Rostislav Dugin
b0510d7c21 FIX (logging): Add login to VictoriaLogs logger 2026-02-06 00:18:09 +03:00
Rostislav Dugin
dc5f271882 Merge pull request #339 from databasus/develop
FIX (storages): Do not remove system storage on any workspace deletion
2026-02-05 01:32:46 +03:00
Rostislav Dugin
8f718771c9 FIX (storages): Do not remove system storage on any workspace deletion 2026-02-05 01:32:21 +03:00
Rostislav Dugin
d8eea05dca Merge pull request #332 from databasus/develop
FIX (script): Fix script creation in playground head x2
2026-02-02 20:46:35 +03:00
Rostislav Dugin
b2a94274d7 FIX (script): Fix script creation in playground head x2 2026-02-02 20:44:52 +03:00
Rostislav Dugin
77c2712ebb Merge pull request #331 from databasus/develop
FIX (script): Fix script creation in playground head
2026-02-02 19:47:44 +03:00
Rostislav Dugin
a9dc29f82c FIX (script): Fix script creation in playground head 2026-02-02 19:47:15 +03:00
Rostislav Dugin
c934a45dca Merge pull request #330 from databasus/develop
FIX (storages): Fix storage edit in playground
2026-02-02 18:51:47 +03:00
Rostislav Dugin
d4acdf2826 FIX (storages): Fix storage edit in playground 2026-02-02 18:48:19 +03:00
Rostislav Dugin
49753c4fc0 Merge pull request #329 from databasus/develop
FIX (s3): Fix S3 prefill in playground on form edit
2026-02-02 18:14:07 +03:00
Rostislav Dugin
c6aed6b36d FIX (s3): Fix S3 prefill in playground on form edit 2026-02-02 18:12:44 +03:00
Rostislav Dugin
3060b4266a Merge pull request #328 from databasus/develop
Develop
2026-02-02 17:53:05 +03:00
Rostislav Dugin
ebeb597f17 FEATURE (playground): Add support of Rybbit script for playground 2026-02-02 17:50:31 +03:00
Rostislav Dugin
4783784325 FIX (playground): Do not show whitelist message in playground 2026-02-02 16:53:01 +03:00
Rostislav Dugin
bd41433bdb Merge branch 'develop' of https://github.com/databasus/databasus into develop 2026-02-02 16:50:18 +03:00
Rostislav Dugin
a9073787d2 FIX (audit logs): In dark mode show white text in audit logs 2026-02-02 16:44:49 +03:00
Rostislav Dugin
0890bf8f09 Merge pull request #327 from artemkalugin01/access-management-href-fix
Fix href in settings for access-management#global-settings
2026-02-02 16:12:25 +03:00
artem.kalugin
f8c11e8802 Fix href typo in settings for access-management#global-settings 2026-02-02 12:59:56 +03:00
Rostislav Dugin
e798d82fc1 Merge pull request #325 from databasus/develop
FIX (storages): Fix default storage type prefill in playground
2026-02-01 20:12:12 +03:00
Rostislav Dugin
81a01585ee FIX (storages): Fix default storage type prefill in playground 2026-02-01 20:07:12 +03:00
Rostislav Dugin
a8465c1a10 Merge pull request #324 from databasus/develop
FIX (storages): Limit local storage usage in playground
2026-02-01 19:20:34 +03:00
Rostislav Dugin
a9e5db70f6 FIX (storages): Limit local storage usage in playground 2026-02-01 19:18:54 +03:00
Rostislav Dugin
7a47be6ca6 Merge pull request #323 from databasus/develop
Develop
2026-02-01 18:42:30 +03:00
Rostislav Dugin
16be3db0c6 FIX (playground): Pre-select system storage if exists in playground 2026-02-01 18:30:50 +03:00
Rostislav Dugin
744e51d1e1 REFACTOR (email): Refactor commit adding date headers to emails 2026-02-01 16:43:53 +03:00
Rostislav Dugin
b3af75d430 Merge branch 'develop' of https://github.com/databasus/databasus into develop 2026-02-01 16:41:52 +03:00
mcarbs
6f7320abeb FIX (email): Add email date header 2026-02-01 16:41:17 +03:00
Rostislav Dugin
a1655d35a6 FIX (healthcheck): Add cache accessibility to healthcheck 2026-01-30 16:33:39 +03:00
Rostislav Dugin
9b6e801184 Merge pull request #316 from databasus/develop
FEATURE (email): Add sending email about members invitation and passw…
2026-01-28 17:29:58 +03:00
Rostislav Dugin
105777ab6f FEATURE (email): Add sending email about members invitation and password reset 2026-01-28 17:28:36 +03:00
Rostislav Dugin
3a1a88d5cf Merge pull request #315 from databasus/develop
FIX (env): Fix env detection over startup
2026-01-28 11:33:06 +03:00
Rostislav Dugin
699ca16814 FIX (env): Fix env detection over startup 2026-01-28 11:32:19 +03:00
Rostislav Dugin
26f3cf233a Merge pull request #313 from databasus/develop
FIX (backups): Improve cascade deletion of backups on storage removal x3
2026-01-27 17:04:25 +03:00
Rostislav Dugin
3d8372e9f6 FIX (backups): Improve cascade deletion of backups on storage removal x3 2026-01-27 17:03:51 +03:00
Rostislav Dugin
b46f11804d Merge pull request #312 from databasus/develop
FIX (backups): Improve cascade deletion of backups on storage removal x2
2026-01-27 16:38:49 +03:00
Rostislav Dugin
4676361688 FIX (backups): Improve cascade deletion of backups on storage removal x2 2026-01-27 16:38:21 +03:00
Databasus
de3679cadf Merge pull request #310 from databasus/develop
FIX (backups): Improve cascade deletion of backups on storage removal
2026-01-27 16:29:13 +03:00
Rostislav Dugin
8f03a30af2 FIX (backups): Improve cascade deletion of backups on storage removal 2026-01-27 16:28:06 +03:00
Rostislav Dugin
356529c58a Merge pull request #309 from databasus/develop
FIX (tests): Fix database backups cleanup when DI does not allow to d…
2026-01-27 15:39:53 +03:00
Rostislav Dugin
e7eed056f7 FIX (tests): Fix database backups cleanup when DI does not allow to delete backups via listeners 2026-01-27 15:39:04 +03:00
Rostislav Dugin
6084cdc954 Merge pull request #308 from databasus/develop
FIX (tests): Increase cascade deletion timeouts in tests
2026-01-27 15:24:15 +03:00
Rostislav Dugin
c50bcc57b1 FIX (tests): Increase cascade deletion timeouts in tests 2026-01-27 15:23:13 +03:00
Rostislav Dugin
ea76300ed7 Merge pull request #307 from databasus/develop
Develop
2026-01-27 15:07:56 +03:00
Rostislav Dugin
9b413e4076 FIX (tests): Improve cleaning up of backups and workspaces 2026-01-27 15:07:20 +03:00
Rostislav Dugin
f91cb260f2 FEATURE (logs): Add Victora Logs 2026-01-27 15:07:20 +03:00
Rostislav Dugin
8f37a8082f FIX (db): Decrease connections count for DB 2026-01-27 15:07:20 +03:00
Rostislav Dugin
5cf7614772 FIX (playground): Make playground multiple nodes 2026-01-24 14:57:45 +03:00
Rostislav Dugin
ae27f74c2e Merge pull request #304 from databasus/develop
FIX (playground): Fix flacky test with impossible value
2026-01-23 12:38:06 +03:00
Rostislav Dugin
9457516bb9 FIX (playground): Fix flacky test with impossible value 2026-01-23 12:37:10 +03:00
Rostislav Dugin
a36fc5bf8c Merge pull request #303 from databasus/develop
Develop
2026-01-23 12:24:29 +03:00
Rostislav Dugin
03ada5806d FEATURE (pre-commit): Add building step to pre-commit 2026-01-23 12:22:31 +03:00
Rostislav Dugin
a6675390e5 FIX (cors): Allow CORS for healthcheck endpoint 2026-01-23 12:04:29 +03:00
Rostislav Dugin
af2f978876 FEATURE (playground): Add playground 2026-01-23 12:00:56 +03:00
Rostislav Dugin
04e7eba5c5 Merge pull request #300 from databasus/develop
FIX (ci \ cd): Add build step after lint step for frontend to catch b…
2026-01-20 08:40:14 +03:00
Rostislav Dugin
520165541d FIX (ci \ cd): Add build step after lint step for frontend to catch build issues 2026-01-20 08:39:28 +03:00
Rostislav Dugin
5b556bc161 Merge pull request #299 from databasus/develop
Develop
2026-01-20 08:26:57 +03:00
Rostislav Dugin
0952a15ec5 FEATURE (navbar): Update navbar style 2026-01-20 08:25:58 +03:00
Rostislav Dugin
1afb3aa3ff Merge pull request #298 from tim-sas-kramp/main
FIX (theme): Integrate theme support for GitHub button color scheme
2026-01-20 07:25:57 +03:00
tim-sas-kramp
19b92e5f74 FIX (theme): Integrate theme support for GitHub button color scheme 2026-01-19 21:17:24 +00:00
Rostislav Dugin
d4763f26b2 Merge pull request #296 from databasus/develop
Develop
2026-01-19 19:27:03 +03:00
Rostislav Dugin
0e389ba16b FIX (backups): Allow parallel backups for different DBs 2026-01-19 19:26:03 +03:00
Rostislav Dugin
594a3294c6 FEATURE (limits): Add max backup size limit and total backups size limit 2026-01-19 19:26:03 +03:00
Rostislav Dugin
4e4a323cf1 FEATURE (config): Suggest read-only user creation when DB config changed 2026-01-19 19:26:03 +03:00
Rostislav Dugin
7d9ecf697b FIX (backups): Do not allow 2 parallel backups for the same DB 2026-01-19 19:26:03 +03:00
Rostislav Dugin
755c420157 Merge pull request #294 from databasus/develop
FIX (mysql \ mariadb): Add escaping underscoped DB names over heath c…
2026-01-19 12:07:18 +03:00
Rostislav Dugin
ff73627287 FIX (mysql \ mariadb): Add escaping underscoped DB names over heath check 2026-01-19 11:34:37 +03:00
Rostislav Dugin
9c9ab00ace Merge pull request #292 from databasus/develop
FIX (postgresql): Do not throw an error over read-only user creation …
2026-01-18 23:08:55 +03:00
Rostislav Dugin
7366e21a1a FIX (postgresql): Do not throw an error over read-only user creation if there are no public schema in DB 2026-01-18 22:57:47 +03:00
Rostislav Dugin
a327d1aa57 Merge pull request #290 from databasus/develop
FIX (ftp): Add support of nested folders
2026-01-18 18:34:45 +03:00
Rostislav Dugin
f152b16ea3 FIX (ftp): Add support of nested folders 2026-01-18 18:34:13 +03:00
Databasus
85dbe80d3d Merge pull request #288 from databasus/develop
FIX (email): Add following RFC 2047 for emails
2026-01-18 17:59:17 +03:00
Rostislav Dugin
edf4028fd1 FIX (email): Add following RFC 2047 for emails 2026-01-18 17:58:31 +03:00
196 changed files with 11179 additions and 1602 deletions

View File

@@ -81,6 +81,11 @@ jobs:
cd frontend
npm run lint
- name: Build frontend
run: |
cd frontend
npm run build
test-frontend:
runs-on: ubuntu-latest
needs: [lint-frontend]

.gitignore (vendored) · 4 changes
View File

@@ -1,3 +1,4 @@
ansible/
postgresus_data/
postgresus-data/
databasus-data/
@@ -9,4 +10,5 @@ node_modules/
/articles
.DS_Store
/scripts
/scripts
.vscode/settings.json

View File

@@ -18,6 +18,13 @@ repos:
files: ^frontend/.*\.(ts|tsx|js|jsx)$
pass_filenames: false
- id: frontend-build
name: Frontend Build
entry: bash -c "cd frontend && npm run build"
language: system
files: ^frontend/.*\.(ts|tsx|js|jsx|json|css)$
pass_filenames: false
# Backend checks
- repo: local
hooks:
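With the new hook in place, it can be exercised on its own; a minimal sketch assuming pre-commit is installed locally and the hook id matches the config above:

```bash
# Run only the new frontend-build hook against the whole tree.
pre-commit run frontend-build --all-files
```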

AGENTS.md · 305 changes
View File

@@ -1,35 +1,37 @@
# Agent Rules and Guidelines
This document contains all coding standards, conventions and best practices recommended for the Databasus project.
This document contains all coding standards, conventions and best practices recommended for the TgTaps project.
This is NOT a strict set of rules, but a set of recommendations to help you write better code.
---
## Table of Contents
- [Engineering Philosophy](#engineering-philosophy)
- [Backend Guidelines](#backend-guidelines)
- [Code Style](#code-style)
- [Engineering philosophy](#engineering-philosophy)
- [Backend guidelines](#backend-guidelines)
- [Code style](#code-style)
- [Boolean naming](#boolean-naming)
- [Add reasonable new lines between logical statements](#add-reasonable-new-lines-between-logical-statements)
- [Comments](#comments)
- [Controllers](#controllers)
- [Dependency Injection (DI)](#dependency-injection-di)
- [Dependency injection (DI)](#dependency-injection-di)
- [Migrations](#migrations)
- [Refactoring](#refactoring)
- [Testing](#testing)
- [Time Handling](#time-handling)
- [CRUD Examples](#crud-examples)
- [Frontend Guidelines](#frontend-guidelines)
- [React Component Structure](#react-component-structure)
- [Time handling](#time-handling)
- [CRUD examples](#crud-examples)
- [Frontend guidelines](#frontend-guidelines)
- [React component structure](#react-component-structure)
---
## Engineering Philosophy
## Engineering philosophy
**Think like a skeptical senior engineer and code reviewer. Don't just do what was asked—also think about what should have been asked.**
⚠️ **Balance vigilance with pragmatism:** Catch real issues, not theoretical ones. Don't let perfect be the enemy of good.
### Task Context Assessment:
### Task context assessment:
**First, assess the task scope:**
@@ -38,7 +40,7 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
- **Complex** (architecture, security, performance-critical): Full analysis required
- **Unclear** (ambiguous requirements): Always clarify assumptions first
### For Non-Trivial Tasks:
### For non-trivial tasks:
1. **Restate the objective and list assumptions** (explicit + implicit)
- If any assumption is shaky, call it out clearly
@@ -71,7 +73,7 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
- Patch the answer accordingly
- Verify edge cases are handled
### Application Guidelines:
### Application guidelines:
**Scale your response to the task:**
@@ -84,9 +86,9 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
---
## Backend Guidelines
## Backend guidelines
### Code Style
### Code style
**Always place private methods to the bottom of file**
@@ -94,7 +96,7 @@ This rule applies to ALL Go files including tests, services, controllers, reposi
In Go, exported (public) functions/methods start with uppercase letters, while unexported (private) ones start with lowercase letters.
#### Structure Order:
#### Structure order:
1. Type definitions and constants
2. Public methods/functions (uppercase)
@@ -227,7 +229,7 @@ func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
}
```
#### Key Points:
#### Key points:
- **Exported/Public** = starts with uppercase letter (CreateUser, GetProject)
- **Unexported/Private** = starts with lowercase letter (validateUser, handleError)
@@ -237,13 +239,13 @@ func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
---
### Boolean Naming
### Boolean naming
**Always prefix boolean variables with verbs like `is`, `has`, `was`, `should`, `can`, etc.**
This makes the code more readable and clearly indicates that the variable represents a true/false state.
#### Good Examples:
#### Good examples:
```go
type User struct {
@@ -265,7 +267,7 @@ wasCompleted := false
hasPermission := checkPermissions()
```
#### Bad Examples:
#### Bad examples:
```go
type User struct {
@@ -286,7 +288,7 @@ completed := false // Should be: wasCompleted
permission := true // Should be: hasPermission
```
#### Common Boolean Prefixes:
#### Common boolean prefixes:
- **is** - current state (IsActive, IsValid, IsEnabled)
- **has** - possession or presence (HasAccess, HasPermission, HasError)
@@ -297,6 +299,167 @@ permission := true // Should be: hasPermission
---
### Add reasonable new lines between logical statements
**Add blank lines between logical blocks to improve code readability.**
Separate different logical operations within a function with blank lines. This makes the code flow clearer and helps identify distinct steps in the logic.
#### Guidelines:
- Add blank line before final `return` statement
- Add blank line after variable declarations before using them
- Add blank line between error handling and subsequent logic
- Add blank line between different logical operations
#### Bad example (without spacing):
```go
func (t *Task) BeforeSave(tx *gorm.DB) error {
if len(t.Messages) > 0 {
messagesBytes, err := json.Marshal(t.Messages)
if err != nil {
return err
}
t.MessagesJSON = string(messagesBytes)
}
return nil
}
func (t *Task) AfterFind(tx *gorm.DB) error {
if t.MessagesJSON != "" {
var messages []onewin_dto.TaskCompletionMessage
if err := json.Unmarshal([]byte(t.MessagesJSON), &messages); err != nil {
return err
}
t.Messages = messages
}
return nil
}
```
#### Good example (with proper spacing):
```go
func (t *Task) BeforeSave(tx *gorm.DB) error {
if len(t.Messages) > 0 {
messagesBytes, err := json.Marshal(t.Messages)
if err != nil {
return err
}
t.MessagesJSON = string(messagesBytes)
}
return nil
}
func (t *Task) AfterFind(tx *gorm.DB) error {
if t.MessagesJSON != "" {
var messages []onewin_dto.TaskCompletionMessage
if err := json.Unmarshal([]byte(t.MessagesJSON), &messages); err != nil {
return err
}
t.Messages = messages
}
return nil
}
```
#### More examples:
**Service method with multiple operations:**
```go
func (s *UserService) CreateUser(request *CreateUserRequest) (*User, error) {
// Validate input
if err := s.validateUserRequest(request); err != nil {
return nil, err
}
// Create user entity
user := &User{
ID: uuid.New(),
Name: request.Name,
Email: request.Email,
}
// Save to database
if err := s.repository.Create(user); err != nil {
return nil, err
}
// Send notification
s.notificationService.SendWelcomeEmail(user.Email)
return user, nil
}
```
**Repository method with query building:**
```go
func (r *Repository) GetFiltered(filters *Filters) ([]*Entity, error) {
query := storage.GetDb().Model(&Entity{})
if filters.Status != "" {
query = query.Where("status = ?", filters.Status)
}
if filters.CreatedAfter != nil {
query = query.Where("created_at > ?", filters.CreatedAfter)
}
var entities []*Entity
if err := query.Find(&entities).Error; err != nil {
return nil, err
}
return entities, nil
}
```
**Repository method with error handling:**
Bad (without spacing):
```go
func (r *Repository) FindById(id uuid.UUID) (*models.Task, error) {
var task models.Task
result := storage.GetDb().Where("id = ?", id).First(&task)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, errors.New("task not found")
}
return nil, result.Error
}
return &task, nil
}
```
Good (with proper spacing):
```go
func (r *Repository) FindById(id uuid.UUID) (*models.Task, error) {
var task models.Task
result := storage.GetDb().Where("id = ?", id).First(&task)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, errors.New("task not found")
}
return nil, result.Error
}
return &task, nil
}
```
---
### Comments
#### Guidelines
@@ -305,13 +468,14 @@ permission := true // Should be: hasPermission
2. **Functions and variables should have meaningful names** - Code should be self-documenting
3. **Comments for unclear code only** - Only add comments when code logic isn't immediately clear
#### Key Principles:
#### Key principles:
- **Code should tell a story** - Use descriptive variable and function names
- **Comments explain WHY, not WHAT** - The code shows what happens, comments explain business logic or complex decisions
- **Prefer refactoring over commenting** - If code needs explaining, consider making it clearer instead
- **API documentation is required** - Swagger comments for all HTTP endpoints are mandatory
- **Complex algorithms deserve comments** - Mathematical formulas, business rules, or non-obvious optimizations
- **Do not write summary sections in .md files unless directly requested** - Avoid adding "Summary" or "Conclusion" sections at the end of documentation files unless the user explicitly asks for them
#### Example of useless comments:
@@ -343,7 +507,7 @@ func CreateValidLogItems(count int, uniqueID string) []logs_receiving.LogItemReq
### Controllers
#### Controller Guidelines:
#### Controller guidelines:
1. **When we write controller:**
- We combine all routes to single controller
@@ -475,7 +639,7 @@ func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
---
### Dependency Injection (DI)
### Dependency injection (DI)
For DI files use **implicit fields declaration styles** (especially for controllers, services, repositories, use cases, etc., not simple data structures).
@@ -503,7 +667,7 @@ var orderController = &OrderController{
**This is needed to avoid forgetting to update DI style when we add new dependency.**
#### Force Such Usage
#### Force such usage
Please force such usage if file look like this (see some services\controllers\repos definitions and getters):
@@ -549,13 +713,13 @@ func GetOrderRepository() *repositories.OrderRepository {
}
```
#### SetupDependencies() Pattern
#### SetupDependencies() pattern
**All `SetupDependencies()` functions must use sync.Once to ensure idempotent execution.**
This pattern allows `SetupDependencies()` to be safely called multiple times (especially in tests) while ensuring the actual setup logic executes only once.
**Implementation Pattern:**
**Implementation pattern:**
```go
package feature
@@ -573,22 +737,22 @@ var (
func SetupDependencies() {
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
// Initialize dependencies here
someService.SetDependency(otherService)
anotherService.AddListener(listener)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}
```
**Why This Pattern:**
**Why this pattern:**
- **Tests can call multiple times**: Test setup often calls `SetupDependencies()` multiple times without issues
- **Thread-safe**: Works correctly with concurrent calls (nanoseconds or seconds apart)
@@ -604,13 +768,13 @@ func SetupDependencies() {
---
### Background Services
### Background services
**All background service `Run()` methods must panic if called multiple times to prevent corrupted states.**
Background services run infinite loops and must never be started twice on the same instance. Multiple calls indicate a serious bug that would cause duplicate goroutines, resource leaks, and data corruption.
**Implementation Pattern:**
**Implementation pattern:**
```go
package feature
@@ -630,14 +794,14 @@ type BackgroundService struct {
func (s *BackgroundService) Run(ctx context.Context) {
wasAlreadyRun := s.hasRun.Load()
s.runOnce.Do(func() {
s.hasRun.Store(true)
// Existing infinite loop logic
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
@@ -647,21 +811,21 @@ func (s *BackgroundService) Run(ctx context.Context) {
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}
```
**Why Panic Instead of Warning:**
**Why panic instead of warning:**
- **Prevents corruption**: Multiple `Run()` calls would create duplicate goroutines consuming resources
- **Fails fast**: Catches critical bugs immediately in tests and production
- **Clear indication**: Panic clearly indicates a serious programming error
- **Applies everywhere**: Same protection in tests and production
**When This Applies:**
**When this applies:**
- All background services with infinite loops
- Registry services (BackupNodesRegistry, RestoreNodesRegistry)
@@ -727,14 +891,14 @@ You can shortify, make more readable, improve code quality, etc. Common logic ca
**After writing tests, always launch them and verify that they pass.**
#### Test Naming Format
#### Test naming format
Use these naming patterns:
- `Test_WhatWeDo_WhatWeExpect`
- `Test_WhatWeDo_WhichConditions_WhatWeExpect`
#### Examples from Real Codebase:
#### Examples from real codebase:
- `Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated`
- `Test_UpdateProject_WhenUserIsProjectAdmin_ProjectUpdated`
@@ -742,22 +906,22 @@ Use these naming patterns:
- `Test_GetProjectAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly`
- `Test_ProjectLifecycleE2E_CompletesSuccessfully`
#### Testing Philosophy
#### Testing philosophy
**Prefer Controllers Over Unit Tests:**
**Prefer controllers over unit tests:**
- Test through HTTP endpoints via controllers whenever possible
- Avoid testing repositories, services in isolation - test via API instead
- Only use unit tests for complex model logic when no API exists
- Name test files `controller_test.go` or `service_test.go`, not `integration_test.go`
**Extract Common Logic to Testing Utilities:**
**Extract common logic to testing utilities:**
- Create `testing.go` or `testing/testing.go` files for shared test utilities
- Extract router creation, user setup, models creation helpers (in API, not just structs creation)
- Reuse common patterns across different test files
**Refactor Existing Tests:**
**Refactor existing tests:**
- When working with existing tests, always look for opportunities to refactor and improve
- Extract repetitive setup code to common utilities
@@ -766,7 +930,44 @@ Use these naming patterns:
- Consolidate similar test patterns across different test files
- Make tests more readable and maintainable for other developers
#### Testing Utilities Structure
**Clean up test data:**
- If the feature supports cleanup operations (DELETE endpoints, cleanup methods), use them in tests
- Clean up resources after test execution to avoid test data pollution
- Use `defer` statements or explicit cleanup calls at the end of tests
- Prioritize using API methods for cleanup (not direct database deletion)
- Examples:
- CRUD features: delete created records via DELETE endpoint
- File uploads: remove uploaded files
- Background jobs: stop schedulers or cancel running tasks
- Skip cleanup only when:
- Tests run in isolated transactions that auto-rollback
- Cleanup endpoint doesn't exist yet
- Test explicitly validates failure scenarios where cleanup isn't possible
**Example:**
```go
func Test_BackupLifecycle_CreateAndDelete(t *testing.T) {
router := createTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test", owner)
// Create backup config
config := createBackupConfig(t, router, workspace.ID, owner.Token)
// Cleanup at end of test
defer deleteBackupConfig(t, router, workspace.ID, config.ID, owner.Token)
// Test operations...
triggerBackup(t, router, workspace.ID, config.ID, owner.Token)
// Verify backup was created
backups := getBackups(t, router, workspace.ID, owner.Token)
assert.NotEmpty(t, backups)
}
```
#### Testing utilities structure
**Create `testing.go` or `testing/testing.go` files with common utilities:**
@@ -802,7 +1003,7 @@ func AddMemberToProject(project *projects_models.Project, member *users_dto.Sign
}
```
#### Controller Test Examples
#### Controller test examples
**Permission-based testing:**
@@ -869,7 +1070,7 @@ func Test_ProjectLifecycleE2E_CompletesSuccessfully(t *testing.T) {
---
### Time Handling
### Time handling
**Always use `time.Now().UTC()` instead of `time.Now()`**
@@ -877,7 +1078,7 @@ This ensures consistent timezone handling across the application.
---
### CRUD Examples
### CRUD examples
This is an example of complete CRUD implementation structure:
@@ -1541,9 +1742,9 @@ func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt ti
---
## Frontend Guidelines
## Frontend guidelines
### React Component Structure
### React component structure
Write React components with the following structure:
@@ -1577,7 +1778,7 @@ export const ReactComponent = ({ someValue }: Props): JSX.Element => {
}
```
#### Structure Order:
#### Structure order:
1. **Props interface** - Define component props
2. **Helper functions** (outside component) - Pure utility functions

View File

@@ -251,6 +251,37 @@ fi
# PostgreSQL 17 binary paths
PG_BIN="/usr/lib/postgresql/17/bin"
# Generate runtime configuration for frontend
echo "Generating runtime configuration..."
# Detect if email is configured (both SMTP_HOST and DATABASUS_URL must be set)
if [ -n "\${SMTP_HOST:-}" ] && [ -n "\${DATABASUS_URL:-}" ]; then
IS_EMAIL_CONFIGURED="true"
else
IS_EMAIL_CONFIGURED="false"
fi
cat > /app/ui/build/runtime-config.js <<JSEOF
// Runtime configuration injected at container startup
// This file is generated dynamically and should not be edited manually
window.__RUNTIME_CONFIG__ = {
IS_CLOUD: '\${IS_CLOUD:-false}',
GITHUB_CLIENT_ID: '\${GITHUB_CLIENT_ID:-}',
GOOGLE_CLIENT_ID: '\${GOOGLE_CLIENT_ID:-}',
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED',
CLOUDFLARE_TURNSTILE_SITE_KEY: '\${CLOUDFLARE_TURNSTILE_SITE_KEY:-}'
};
JSEOF
# Inject analytics script if provided (only if not already injected)
if [ -n "\${ANALYTICS_SCRIPT:-}" ]; then
if ! grep -q "rybbit.databasus.com" /app/ui/build/index.html 2>/dev/null; then
echo "Injecting analytics script..."
sed -i "s#</head># \${ANALYTICS_SCRIPT}\\
</head>#" /app/ui/build/index.html
fi
fi
# Ensure proper ownership of data directory
echo "Setting up data directory permissions..."
mkdir -p /databasus-data/pgdata
@@ -372,9 +403,37 @@ SQL
# Start the main application
echo "Starting Databasus application..."
# Check and warn about external database/Valkey usage
if [ -n "\${DANGEROUS_EXTERNAL_DATABASE_DSN:-}" ]; then
echo ""
echo "=========================================="
echo "WARNING: Using external database"
echo "=========================================="
echo "DANGEROUS_EXTERNAL_DATABASE_DSN is set."
echo "Application will connect to external PostgreSQL instead of internal instance."
echo "Internal PostgreSQL is still running in the background."
echo "=========================================="
echo ""
fi
if [ -n "\${DANGEROUS_VALKEY_HOST:-}" ]; then
echo ""
echo "=========================================="
echo "WARNING: Using external Valkey"
echo "=========================================="
echo "DANGEROUS_VALKEY_HOST is set."
echo "Application will connect to external Valkey instead of internal instance."
echo "Internal Valkey is still running in the background."
echo "=========================================="
echo ""
fi
exec ./main
EOF
LABEL org.opencontainers.image.source="https://github.com/databasus/databasus"
RUN chmod +x /app/start.sh
EXPOSE 4005
@@ -383,4 +442,4 @@ EXPOSE 4005
VOLUME ["/databasus-data"]
ENTRYPOINT ["/app/start.sh"]
CMD []
CMD []
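Putting the new overrides together, a sketch of running the container against an external PostgreSQL and Valkey; the image name, host names, credentials, and DSN value are illustrative assumptions, while the variable names, port, and volume path come from this diff:

```bash
# External database and Valkey via the DANGEROUS_* overrides added in this diff.
# Hosts, credentials, and the DSN below are placeholders.
docker run -d \
  -p 4005:4005 \
  -v databasus-data:/databasus-data \
  -e DANGEROUS_EXTERNAL_DATABASE_DSN="host=db.example.com user=postgres password=secret dbname=databasus port=5432 sslmode=disable" \
  -e DANGEROUS_VALKEY_HOST=valkey.example.com \
  -e DANGEROUS_VALKEY_PORT=6379 \
  -e DANGEROUS_VALKEY_PASSWORD=secret \
  databasus/databasus:latest
```

As the startup script notes, the internal PostgreSQL and Valkey keep running in the background when these overrides are set; the warnings above are printed at startup.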

View File

@@ -11,7 +11,7 @@
[![MongoDB](https://img.shields.io/badge/MongoDB-47A248?logo=mongodb&logoColor=white)](https://www.mongodb.com/)
<br />
[![Apache 2.0 License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
[![Docker Pulls](https://img.shields.io/docker/pulls/rostislavdugin/postgresus?color=brightgreen)](https://hub.docker.com/r/rostislavdugin/postgresus)
[![Docker Pulls](https://img.shields.io/docker/pulls/databasus/databasus?color=brightgreen)](https://hub.docker.com/r/databasus/databasus)
[![Platform](https://img.shields.io/badge/platform-linux%20%7C%20macos%20%7C%20windows-lightgrey)](https://github.com/databasus/databasus)
[![Self Hosted](https://img.shields.io/badge/self--hosted-yes-brightgreen)](https://github.com/databasus/databasus)
[![Open Source](https://img.shields.io/badge/open%20source-❤️-red)](https://github.com/databasus/databasus)
@@ -31,8 +31,6 @@
<img src="assets/dashboard-dark.svg" alt="Databasus Dark Dashboard" width="800" style="margin-bottom: 10px;"/>
<img src="assets/dashboard.svg" alt="Databasus Dashboard" width="800"/>
</div>
---

View File

@@ -6,9 +6,14 @@ DEV_DB_PASSWORD=Q1234567
ENV_MODE=development
# logging
SHOW_DB_INSTALLATION_VERIFICATION_LOGS=true
VICTORIA_LOGS_URL=http://localhost:9428
VICTORIA_LOGS_PASSWORD=devpassword
# tests
TEST_LOCALHOST=localhost
IS_SKIP_EXTERNAL_RESOURCES_TESTS=false
# cloudflare turnstile
CLOUDFLARE_TURNSTILE_SITE_KEY=
CLOUDFLARE_TURNSTILE_SECRET_KEY=
# db
DATABASE_DSN=host=dev-db user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable

backend/.gitignore (vendored) · 3 changes
View File

@@ -18,4 +18,5 @@ pgdata-for-restore/
temp/
cmd.exe
temp/
valkey-data/
valkey-data/
victoria-logs-data/

View File

@@ -185,6 +185,9 @@ func startServerWithGracefulShutdown(log *slog.Logger, app *gin.Engine) {
<-quit
log.Info("Shutdown signal received")
// Gracefully shutdown VictoriaLogs writer
logger.ShutdownVictoriaLogs(5 * time.Second)
// The context is used to inform the server it has 10 seconds to finish
// the request it is currently handling
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -272,6 +275,10 @@ func runBackgroundTasks(log *slog.Logger) {
backuping.GetBackupsScheduler().Run(ctx)
})
go runWithPanicLogging(log, "backup cleaner background service", func() {
backuping.GetBackupCleaner().Run(ctx)
})
go runWithPanicLogging(log, "restore background service", func() {
restoring.GetRestoresScheduler().Run(ctx)
})

View File

@@ -34,6 +34,20 @@ services:
retries: 5
start_period: 20s
# VictoriaLogs for external logging
victoria-logs:
image: victoriametrics/victoria-logs:latest
container_name: victoria-logs
ports:
- "9428:9428"
command:
- -storageDataPath=/victoria-logs-data
- -retentionPeriod=7d
- -httpAuth.password=devpassword
volumes:
- ./victoria-logs-data:/victoria-logs-data
restart: unless-stopped
# Test MinIO container
test-minio:
image: minio/minio:latest

View File

@@ -22,13 +22,22 @@ const (
type EnvVariables struct {
IsTesting bool
DatabaseDsn string `env:"DATABASE_DSN" required:"true"`
EnvMode env_utils.EnvMode `env:"ENV_MODE" required:"true"`
PostgresesInstallDir string `env:"POSTGRES_INSTALL_DIR"`
MysqlInstallDir string `env:"MYSQL_INSTALL_DIR"`
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`
// Internal database
DatabaseDsn string `env:"DATABASE_DSN" required:"true"`
// Internal Valkey
ValkeyHost string `env:"VALKEY_HOST" required:"true"`
ValkeyPort string `env:"VALKEY_PORT" required:"true"`
ValkeyUsername string `env:"VALKEY_USERNAME" required:"true"`
ValkeyPassword string `env:"VALKEY_PASSWORD" required:"true"`
ValkeyIsSsl bool `env:"VALKEY_IS_SSL" required:"true"`
IsCloud bool `env:"IS_CLOUD"`
TestLocalhost string `env:"TEST_LOCALHOST"`
ShowDbInstallationVerificationLogs bool `env:"SHOW_DB_INSTALLATION_VERIFICATION_LOGS"`
@@ -89,19 +98,16 @@ type EnvVariables struct {
TestMongodb70Port string `env:"TEST_MONGODB_70_PORT"`
TestMongodb82Port string `env:"TEST_MONGODB_82_PORT"`
// Valkey
ValkeyHost string `env:"VALKEY_HOST" required:"true"`
ValkeyPort string `env:"VALKEY_PORT" required:"true"`
ValkeyUsername string `env:"VALKEY_USERNAME"`
ValkeyPassword string `env:"VALKEY_PASSWORD"`
ValkeyIsSsl bool `env:"VALKEY_IS_SSL" required:"true"`
// oauth
GitHubClientID string `env:"GITHUB_CLIENT_ID"`
GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET"`
GoogleClientID string `env:"GOOGLE_CLIENT_ID"`
GoogleClientSecret string `env:"GOOGLE_CLIENT_SECRET"`
// Cloudflare Turnstile
CloudflareTurnstileSecretKey string `env:"CLOUDFLARE_TURNSTILE_SECRET_KEY"`
CloudflareTurnstileSiteKey string `env:"CLOUDFLARE_TURNSTILE_SITE_KEY"`
// testing Telegram
TestTelegramBotToken string `env:"TEST_TELEGRAM_BOT_TOKEN"`
TestTelegramChatID string `env:"TEST_TELEGRAM_CHAT_ID"`
@@ -112,6 +118,15 @@ type EnvVariables struct {
TestSupabaseUsername string `env:"TEST_SUPABASE_USERNAME"`
TestSupabasePassword string `env:"TEST_SUPABASE_PASSWORD"`
TestSupabaseDatabase string `env:"TEST_SUPABASE_DATABASE"`
// SMTP configuration (optional)
SMTPHost string `env:"SMTP_HOST"`
SMTPPort int `env:"SMTP_PORT"`
SMTPUser string `env:"SMTP_USER"`
SMTPPassword string `env:"SMTP_PASSWORD"`
// Application URL (optional) - used for email links
DatabasusURL string `env:"DATABASUS_URL"`
}
var (
@@ -182,6 +197,11 @@ func loadEnvVariables() {
env.IsSkipExternalResourcesTests = false
}
// Set default value for IsCloud if not defined
if os.Getenv("IS_CLOUD") == "" {
env.IsCloud = false
}
for _, arg := range os.Args {
if strings.Contains(arg, "test") {
env.IsTesting = true
@@ -189,6 +209,14 @@ func loadEnvVariables() {
}
}
// Check for external database override
if externalDsn := os.Getenv("DANGEROUS_EXTERNAL_DATABASE_DSN"); externalDsn != "" {
log.Warn(
"Using DANGEROUS_EXTERNAL_DATABASE_DSN - connecting to external database instead of internal PostgreSQL",
)
env.DatabaseDsn = externalDsn
}
if env.DatabaseDsn == "" {
log.Error("DATABASE_DSN is empty")
os.Exit(1)
@@ -259,6 +287,27 @@ func loadEnvVariables() {
os.Exit(1)
}
// Check for external Valkey override
if externalValkeyHost := os.Getenv("DANGEROUS_VALKEY_HOST"); externalValkeyHost != "" {
log.Warn(
"Using DANGEROUS_VALKEY_* variables - connecting to external Valkey instead of internal instance",
)
env.ValkeyHost = externalValkeyHost
if externalValkeyPort := os.Getenv("DANGEROUS_VALKEY_PORT"); externalValkeyPort != "" {
env.ValkeyPort = externalValkeyPort
}
if externalValkeyUsername := os.Getenv("DANGEROUS_VALKEY_USERNAME"); externalValkeyUsername != "" {
env.ValkeyUsername = externalValkeyUsername
}
if externalValkeyPassword := os.Getenv("DANGEROUS_VALKEY_PASSWORD"); externalValkeyPassword != "" {
env.ValkeyPassword = externalValkeyPassword
}
if externalValkeyIsSsl := os.Getenv("DANGEROUS_VALKEY_IS_SSL"); externalValkeyIsSsl != "" {
env.ValkeyIsSsl = externalValkeyIsSsl == "true"
}
}
// Store the data and temp folders one level below the root
// (projectRoot/databasus-data -> /databasus-data)
env.DataFolder = filepath.Join(filepath.Dir(backendRoot), "databasus-data", "backups")

View File

@@ -70,16 +70,18 @@ func (n *BackuperNode) Run(ctx context.Context) {
}
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
n.MakeBackup(backupID, isCallNotifier)
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
err,
"backupID",
backupID,
)
}
go func() {
n.MakeBackup(backupID, isCallNotifier)
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
err,
"backupID",
backupID,
)
}
}()
}
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
@@ -157,30 +159,73 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
start := time.Now().UTC()
ctx, cancel := context.WithCancel(context.Background())
n.backupCancelManager.RegisterTask(backup.ID, cancel)
defer n.backupCancelManager.UnregisterTask(backup.ID)
backupProgressListener := func(
completedMBs float64,
) {
backup.BackupSizeMb = completedMBs
backup.BackupDurationMs = time.Since(start).Milliseconds()
// Check size limit (0 = unlimited)
if backupConfig.MaxBackupSizeMB > 0 &&
completedMBs > float64(backupConfig.MaxBackupSizeMB) {
errMsg := fmt.Sprintf(
"backup size (%.2f MB) exceeded maximum allowed size (%d MB)",
completedMBs,
backupConfig.MaxBackupSizeMB,
)
backup.Status = backups_core.BackupStatusFailed
backup.IsSkipRetry = true
backup.FailMessage = &errMsg
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to save backup with size exceeded error", "error", err)
}
cancel() // Cancel the backup context
return
}
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to update backup progress", "error", err)
}
}
ctx, cancel := context.WithCancel(context.Background())
n.backupCancelManager.RegisterTask(backup.ID, cancel)
defer n.backupCancelManager.UnregisterTask(backup.ID)
backupMetadata, err := n.createBackupUseCase.Execute(
ctx,
backup.ID,
backup,
backupConfig,
database,
storage,
backupProgressListener,
)
if err != nil {
// Check if backup was already marked as failed by progress listener (e.g., size limit exceeded)
// If so, skip error handling to avoid overwriting the status
currentBackup, fetchErr := n.backupRepository.FindByID(backup.ID)
if fetchErr == nil && currentBackup.Status == backups_core.BackupStatusFailed {
n.logger.Warn(
"Backup already marked as failed by progress listener, skipping error handling",
"backupId",
backup.ID,
"failMessage",
*currentBackup.FailMessage,
)
// Still call notification for size limit failures
n.SendBackupNotification(
backupConfig,
currentBackup,
backups_config.NotificationBackupFailed,
currentBackup.FailMessage,
)
return
}
errMsg := err.Error()
// Log detailed error information for debugging
@@ -218,7 +263,7 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
// Delete partial backup from storage
storage, storageErr := n.storageService.GetStorageByID(backup.StorageID)
if storageErr == nil {
if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.ID); deleteErr != nil {
if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.ID.String()); deleteErr != nil {
n.logger.Error(
"Failed to delete partial backup file",
"backupId",

View File

@@ -1,13 +1,10 @@
package backuping
import (
"context"
"errors"
"strings"
"testing"
"time"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -18,7 +15,6 @@ import (
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
cache_utils "databasus-backend/internal/util/cache"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
@@ -158,35 +154,120 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
})
}
type CreateFailedBackupUsecase struct {
}
func Test_BackupSizeLimits(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
func (uc *CreateFailedBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10)
return nil, errors.New("backup failed")
}
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
type CreateSuccessBackupUsecase struct{}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
func (uc *CreateSuccessBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
t.Run("UnlimitedSize_MaxBackupSizeMBIsZero_BackupCompletes", func(t *testing.T) {
// Enable backups with unlimited size (0)
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 0 // unlimited
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateLargeBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup completed successfully even with large size
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
assert.Equal(t, float64(10000), updatedBackup.BackupSizeMb)
assert.Nil(t, updatedBackup.FailMessage)
})
t.Run("SizeExceeded_BackupFailedWithIsSkipRetry", func(t *testing.T) {
// Enable backups with 5 MB limit
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 5
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateProgressiveBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup was marked as failed with IsSkipRetry=true
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusFailed, updatedBackup.Status)
assert.True(t, updatedBackup.IsSkipRetry)
assert.NotNil(t, updatedBackup.FailMessage)
assert.Contains(t, *updatedBackup.FailMessage, "exceeded maximum allowed size")
assert.Contains(t, *updatedBackup.FailMessage, "10.00 MB")
assert.Contains(t, *updatedBackup.FailMessage, "5 MB")
assert.Greater(t, updatedBackup.BackupSizeMb, float64(5))
})
t.Run("SizeWithinLimit_BackupCompletes", func(t *testing.T) {
// Enable backups with 100 MB limit
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 100
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateMediumBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup completed successfully
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
assert.Equal(t, float64(50), updatedBackup.BackupSizeMb)
assert.Nil(t, updatedBackup.FailMessage)
})
}

View File

@@ -0,0 +1,242 @@
package backuping
import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
)
const (
cleanerTickerInterval = 1 * time.Minute
)
type BackupCleaner struct {
backupRepository *backups_core.BackupRepository
storageService *storages.StorageService
backupConfigService *backups_config.BackupConfigService
fieldEncryptor util_encryption.FieldEncryptor
logger *slog.Logger
backupRemoveListeners []backups_core.BackupRemoveListener
runOnce sync.Once
hasRun atomic.Bool
}
func (c *BackupCleaner) Run(ctx context.Context) {
wasAlreadyRun := c.hasRun.Load()
c.runOnce.Do(func() {
c.hasRun.Store(true)
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(cleanerTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := c.cleanOldBackups(); err != nil {
c.logger.Error("Failed to clean old backups", "error", err)
}
if err := c.cleanExceededBackups(); err != nil {
c.logger.Error("Failed to clean exceeded backups", "error", err)
}
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", c))
}
}
func (c *BackupCleaner) DeleteBackup(backup *backups_core.Backup) error {
for _, listener := range c.backupRemoveListeners {
if err := listener.OnBeforeBackupRemove(backup); err != nil {
return err
}
}
storage, err := c.storageService.GetStorageByID(backup.StorageID)
if err != nil {
return err
}
err = storage.DeleteFile(c.fieldEncryptor, backup.ID.String())
if err != nil {
// we do not return error here, because sometimes clean up performed
// before unavailable storage removal or change - therefore we should
// proceed even in case of error. It's possible that some S3 or
// storage is not available yet, it should not block us
c.logger.Error("Failed to delete backup file", "error", err)
}
return c.backupRepository.DeleteByID(backup.ID)
}
func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
}
func (c *BackupCleaner) cleanOldBackups() error {
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
backupStorePeriod := backupConfig.StorePeriod
if backupStorePeriod == period.PeriodForever {
continue
}
storeDuration := backupStorePeriod.ToDuration()
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
oldBackups, err := c.backupRepository.FindBackupsBeforeDate(
backupConfig.DatabaseID,
dateBeforeBackupsShouldBeDeleted,
)
if err != nil {
c.logger.Error(
"Failed to find old backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
for _, backup := range oldBackups {
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
continue
}
c.logger.Info(
"Deleted old backup",
"backupId",
backup.ID,
"databaseId",
backupConfig.DatabaseID,
)
}
}
return nil
}
func (c *BackupCleaner) cleanExceededBackups() error {
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
if backupConfig.MaxBackupsTotalSizeMB <= 0 {
continue
}
if err := c.cleanExceededBackupsForDatabase(
backupConfig.DatabaseID,
backupConfig.MaxBackupsTotalSizeMB,
); err != nil {
c.logger.Error(
"Failed to clean exceeded backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
}
return nil
}
func (c *BackupCleaner) cleanExceededBackupsForDatabase(
databaseID uuid.UUID,
limitperDbMB int64,
) error {
for {
backupsTotalSizeMB, err := c.backupRepository.GetTotalSizeByDatabase(databaseID)
if err != nil {
return err
}
if backupsTotalSizeMB <= float64(limitperDbMB) {
break
}
oldestBackups, err := c.backupRepository.FindOldestByDatabaseExcludingInProgress(
databaseID,
1,
)
if err != nil {
return err
}
if len(oldestBackups) == 0 {
c.logger.Warn(
"No backups to delete but still over limit",
"databaseId",
databaseID,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
)
break
}
backup := oldestBackups[0]
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete exceeded backup",
"backupId",
backup.ID,
"databaseId",
databaseID,
"error",
err,
)
return err
}
c.logger.Info(
"Deleted exceeded backup",
"backupId",
backup.ID,
"databaseId",
databaseID,
"backupSizeMB",
backup.BackupSizeMb,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
)
}
return nil
}

View File

@@ -0,0 +1,595 @@
package backuping
import (
"testing"
"time"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/storage"
"databasus-backend/internal/util/period"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_CleanOldBackups_DeletesBackupsOlderThanStorePeriod(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Create backup interval
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create backups with different ages
now := time.Now().UTC()
oldBackup1 := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-10 * 24 * time.Hour), // 10 days old
}
oldBackup2 := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-8 * 24 * time.Hour), // 8 days old
}
recentBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-3 * 24 * time.Hour), // 3 days old
}
err = backupRepository.Save(oldBackup1)
assert.NoError(t, err)
err = backupRepository.Save(oldBackup2)
assert.NoError(t, err)
err = backupRepository.Save(recentBackup)
assert.NoError(t, err)
// Run cleanup
cleaner := GetBackupCleaner()
err = cleaner.cleanOldBackups()
assert.NoError(t, err)
// Verify old backups deleted, recent backup remains
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 1, len(remainingBackups))
assert.Equal(t, recentBackup.ID, remainingBackups[0].ID)
}
func Test_CleanOldBackups_SkipsDatabaseWithForeverStorePeriod(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Create backup interval
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create very old backup
oldBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: time.Now().UTC().Add(-365 * 24 * time.Hour), // 1 year old
}
err = backupRepository.Save(oldBackup)
assert.NoError(t, err)
// Run cleanup
cleaner := GetBackupCleaner()
err = cleaner.cleanOldBackups()
assert.NoError(t, err)
// Verify backup still exists
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 1, len(remainingBackups))
assert.Equal(t, oldBackup.ID, remainingBackups[0].ID)
}
func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Create backup interval
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 100, // 100 MB limit
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create 3 backups totaling 50MB (under limit)
for i := 0; i < 3; i++ {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 16.67,
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
}
// Run cleanup
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
assert.NoError(t, err)
// Verify all backups remain
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 3, len(remainingBackups))
}
func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Create backup interval
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 30, // 30 MB limit
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create 5 backups of 10MB each (total 50MB, over 30MB limit)
now := time.Now().UTC()
var backupIDs []uuid.UUID
for i := 0; i < 5; i++ {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-time.Duration(4-i) * time.Hour), // Oldest first
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backupIDs = append(backupIDs, backup.ID)
}
// Run cleanup
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
assert.NoError(t, err)
// Verify 2 oldest backups deleted, 3 newest remain
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 3, len(remainingBackups))
// Check that the newest 3 backups remain
remainingIDs := make(map[uuid.UUID]bool)
for _, backup := range remainingBackups {
remainingIDs[backup.ID] = true
}
assert.False(t, remainingIDs[backupIDs[0]]) // Oldest deleted
assert.False(t, remainingIDs[backupIDs[1]]) // 2nd oldest deleted
assert.True(t, remainingIDs[backupIDs[2]]) // 3rd remains
assert.True(t, remainingIDs[backupIDs[3]]) // 4th remains
assert.True(t, remainingIDs[backupIDs[4]]) // Newest remains
}
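// A minimal sketch (an assumption, not the actual BackupCleaner implementation)
// of the newest-first accounting the test above verifies: walk backups from
// newest to oldest, keep them while the running total stays within the
// per-database limit, and mark everything beyond that point for deletion.
// In-progress backups are skipped entirely, matching the test further below.
func selectBackupsToDelete(backupsNewestFirst []*backups_core.Backup, limitMB float64) []*backups_core.Backup {
	var toDelete []*backups_core.Backup
	totalMB := 0.0
	for _, backup := range backupsNewestFirst {
		if backup.Status == backups_core.BackupStatusInProgress {
			continue // never trim a backup that is still being written
		}
		totalMB += backup.BackupSizeMb
		if totalMB > limitMB {
			toDelete = append(toDelete, backup)
		}
	}
	return toDelete
}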
func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Create backup interval
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 50, // 50 MB limit
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
now := time.Now().UTC()
// Create 3 completed backups of 30MB each
completedBackups := make([]*backups_core.Backup, 3)
for i := 0; i < 3; i++ {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 30,
CreatedAt: now.Add(-time.Duration(3-i) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
completedBackups[i] = backup
}
// Create 1 in-progress backup (should be excluded from size calculation and deletion)
inProgressBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 10,
CreatedAt: now,
}
err = backupRepository.Save(inProgressBackup)
assert.NoError(t, err)
// Run cleanup
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
assert.NoError(t, err)
// Verify: only completed backups deleted, in-progress remains
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
// Should have the in-progress backup plus 1 completed backup (30MB completed + 10MB in-progress = 40MB total)
assert.GreaterOrEqual(t, len(remainingBackups), 2)
// Verify in-progress backup still exists
var inProgressFound bool
for _, backup := range remainingBackups {
if backup.ID == inProgressBackup.ID {
inProgressFound = true
assert.Equal(t, backups_core.BackupStatusInProgress, backup.Status)
}
}
assert.True(t, inProgressFound, "In-progress backup should not be deleted")
}
func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Create backup interval
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 0, // No size limit
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create large backups
for i := 0; i < 10; i++ {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 100,
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
}
// Run cleanup
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
assert.NoError(t, err)
// Verify all backups remain
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 10, len(remainingBackups))
}
func Test_GetTotalSizeByDatabase_CalculatesCorrectly(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Create completed backups
completedBackup1 := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10.5,
CreatedAt: time.Now().UTC(),
}
completedBackup2 := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 20.3,
CreatedAt: time.Now().UTC(),
}
// Create failed backup (should be included in the total size)
failedBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusFailed,
BackupSizeMb: 5.2,
CreatedAt: time.Now().UTC(),
}
// Create in-progress backup (should be excluded from the total size)
inProgressBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 100,
CreatedAt: time.Now().UTC(),
}
err := backupRepository.Save(completedBackup1)
assert.NoError(t, err)
err = backupRepository.Save(completedBackup2)
assert.NoError(t, err)
err = backupRepository.Save(failedBackup)
assert.NoError(t, err)
err = backupRepository.Save(inProgressBackup)
assert.NoError(t, err)
// Calculate total size
totalSize, err := backupRepository.GetTotalSizeByDatabase(database.ID)
assert.NoError(t, err)
// Should be 10.5 + 20.3 + 5.2 = 36.0 (excluding in-progress 100)
assert.InDelta(t, 36.0, totalSize, 0.1)
}
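// A rough in-memory equivalent (an assumption, not the real repository query)
// of what GetTotalSizeByDatabase is expected to return for this test to pass:
// completed and failed backups both count toward the total, while in-progress
// backups are excluded because their size is still changing.
func sumBackupSizeMB(backups []*backups_core.Backup) float64 {
	total := 0.0
	for _, backup := range backups {
		if backup.Status == backups_core.BackupStatusInProgress {
			continue
		}
		total += backup.BackupSizeMb
	}
	return total
}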
// Mock listener for testing
type mockBackupRemoveListener struct {
onBeforeBackupRemove func(*backups_core.Backup) error
}
func (m *mockBackupRemoveListener) OnBeforeBackupRemove(backup *backups_core.Backup) error {
if m.onBeforeBackupRemove != nil {
return m.onBeforeBackupRemove(backup)
}
return nil
}
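// Hypothetical usage of the mock above (the listener registration API on
// BackupCleaner is not shown in this diff): build a listener that records
// which backups the cleaner is about to remove so a test can assert on them.
func newRecordingRemoveListener(removed *[]uuid.UUID) *mockBackupRemoveListener {
	return &mockBackupRemoveListener{
		onBeforeBackupRemove: func(backup *backups_core.Backup) error {
			*removed = append(*removed, backup.ID)
			return nil
		},
	}
}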
// Test_DeleteBackup_WhenStorageDeleteFails_BackupStillRemovedFromDatabase verifies resilience
// when storage becomes unavailable. Even if storage.DeleteFile fails (e.g., storage is offline,
// credentials changed, or storage was deleted), the backup record should still be removed from
// the database. This prevents orphaned backup records when storage is no longer accessible.
func Test_DeleteBackup_WhenStorageDeleteFails_BackupStillRemovedFromDatabase(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
testStorage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, testStorage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(testStorage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: testStorage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: time.Now().UTC(),
}
err := backupRepository.Save(backup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.DeleteBackup(backup)
assert.NoError(t, err, "DeleteBackup should succeed even when storage file doesn't exist")
deletedBackup, err := backupRepository.FindByID(backup.ID)
assert.Error(t, err, "Backup should not exist in database")
assert.Nil(t, deletedBackup)
}
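// A minimal sketch (hypothetical helpers, not the real DeleteBackup body) of
// the best-effort ordering the test above relies on: attempt the storage
// delete first, log on failure, and remove the database record regardless so
// an unreachable storage cannot leave orphaned backup rows behind.
func deleteBackupBestEffort(
	deleteFile func(backupID uuid.UUID) error,
	deleteRecord func(backupID uuid.UUID) error,
	logWarn func(msg string, err error),
	backup *backups_core.Backup,
) error {
	if err := deleteFile(backup.ID); err != nil {
		// Storage may be offline, credentials may have rotated, or the file
		// may already be gone: record the problem and keep going.
		logWarn("failed to delete backup file from storage", err)
	}
	return deleteRecord(backup.ID)
}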
func createTestInterval() *intervals.Interval {
timeOfDay := "04:00"
interval := &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
err := storage.GetDb().Create(interval).Error
if err != nil {
panic(err)
}
return interval
}

View File

@@ -24,14 +24,25 @@ var backupRepository = &backups_core.BackupRepository{}
var taskCancelManager = tasks_cancellation.GetTaskCancelManager()
var backupCleaner = &BackupCleaner{
backupRepository,
storages.GetStorageService(),
backups_config.GetBackupConfigService(),
encryption.GetFieldEncryptor(),
logger.GetLogger(),
[]backups_core.BackupRemoveListener{},
sync.Once{},
atomic.Bool{},
}
var backupNodesRegistry = &BackupNodesRegistry{
client: cache_utils.GetValkeyClient(),
logger: logger.GetLogger(),
timeout: cache_utils.DefaultCacheTimeout,
pubsubBackups: cache_utils.NewPubSubManager(),
pubsubCompletions: cache_utils.NewPubSubManager(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
sync.Once{},
atomic.Bool{},
}
func getNodeID() uuid.UUID {
@@ -39,35 +50,35 @@ func getNodeID() uuid.UUID {
}
var backuperNode = &BackuperNode{
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
nodeID: getNodeID(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
taskCancelManager,
backupNodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
getNodeID(),
time.Time{},
sync.Once{},
atomic.Bool{},
}
var backupsScheduler = &BackupsScheduler{
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
taskCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
lastBackupTime: time.Now().UTC(),
logger: logger.GetLogger(),
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
backuperNode: backuperNode,
runOnce: sync.Once{},
hasRun: atomic.Bool{},
backupRepository,
backups_config.GetBackupConfigService(),
taskCancelManager,
backupNodesRegistry,
databases.GetDatabaseService(),
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]BackupToNodeRelation),
backuperNode,
sync.Once{},
atomic.Bool{},
}
func GetBackupsScheduler() *BackupsScheduler {
@@ -81,3 +92,7 @@ func GetBackuperNode() *BackuperNode {
func GetBackupNodesRegistry() *BackupNodesRegistry {
return backupNodesRegistry
}
func GetBackupCleaner() *BackupCleaner {
return backupCleaner
}

View File

@@ -1,8 +1,19 @@
package backuping
import (
"databasus-backend/internal/features/notifiers"
"context"
"errors"
"sync/atomic"
"time"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
"github.com/google/uuid"
"github.com/stretchr/testify/mock"
)
@@ -17,3 +28,168 @@ func (m *MockNotificationSender) SendNotification(
) {
m.Called(notifier, title, message)
}
type CreateFailedBackupUsecase struct{}
func (uc *CreateFailedBackupUsecase) Execute(
ctx context.Context,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10)
return nil, errors.New("backup failed")
}
type CreateSuccessBackupUsecase struct{}
func (uc *CreateSuccessBackupUsecase) Execute(
ctx context.Context,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
// CreateLargeBackupUsecase simulates a large backup (10000 MB)
type CreateLargeBackupUsecase struct{}
func (uc *CreateLargeBackupUsecase) Execute(
ctx context.Context,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10000)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
// CreateProgressiveBackupUsecase simulates progressive size updates that exceed limit
type CreateProgressiveBackupUsecase struct{}
func (uc *CreateProgressiveBackupUsecase) Execute(
ctx context.Context,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
// Simulate progressive backup that grows beyond limit
backupProgressListener(1)
if ctx.Err() != nil {
return nil, ctx.Err()
}
backupProgressListener(3)
if ctx.Err() != nil {
return nil, ctx.Err()
}
backupProgressListener(5)
if ctx.Err() != nil {
return nil, ctx.Err()
}
backupProgressListener(10) // This exceeds the 5 MB limit
if ctx.Err() != nil {
return nil, ctx.Err()
}
// Should not reach here due to cancellation
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
// CreateMediumBackupUsecase simulates a 50 MB backup
type CreateMediumBackupUsecase struct{}
func (uc *CreateMediumBackupUsecase) Execute(
ctx context.Context,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(50)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
// MockTrackingBackupUsecase tracks backup use case calls for testing parallel execution
type MockTrackingBackupUsecase struct {
callCount atomic.Int32
calledBackupIDs chan uuid.UUID
}
func NewMockTrackingBackupUsecase() *MockTrackingBackupUsecase {
return &MockTrackingBackupUsecase{
calledBackupIDs: make(chan uuid.UUID, 10),
}
}
func (m *MockTrackingBackupUsecase) Execute(
ctx context.Context,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
m.callCount.Add(1)
// Send backup ID to channel (non-blocking)
select {
case m.calledBackupIDs <- backup.ID:
default:
}
// Simulate backup work
time.Sleep(100 * time.Millisecond)
backupProgressListener(10)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
func (m *MockTrackingBackupUsecase) GetCallCount() int32 {
return m.callCount.Load()
}
func (m *MockTrackingBackupUsecase) GetCalledBackupIDs() []uuid.UUID {
ids := []uuid.UUID{}
for {
select {
case id := <-m.calledBackupIDs:
ids = append(ids, id)
default:
return ids
}
}
}
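// Hypothetical helper (not part of this diff) built on the tracking mock: block
// a test until the expected number of backup use-case calls has been observed
// or the timeout elapses, as an alternative to assert.Eventually.
func waitForUseCaseCalls(m *MockTrackingBackupUsecase, want int32, timeout time.Duration) bool {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		if m.GetCallCount() >= want {
			return true
		}
		time.Sleep(50 * time.Millisecond)
	}
	return m.GetCallCount() >= want
}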

View File

@@ -13,10 +13,9 @@ import (
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/features/databases"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
files_utils "databasus-backend/internal/util/files"
)
const (
@@ -28,9 +27,9 @@ const (
type BackupsScheduler struct {
backupRepository *backups_core.BackupRepository
backupConfigService *backups_config.BackupConfigService
storageService *storages.StorageService
taskCancelManager *task_cancellation.TaskCancelManager
backupNodesRegistry *BackupNodesRegistry
databaseService *databases.DatabaseService
lastBackupTime time.Time
logger *slog.Logger
@@ -84,10 +83,6 @@ func (s *BackupsScheduler) Run(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.cleanOldBackups(); err != nil {
s.logger.Error("Failed to clean old backups", "error", err)
}
if err := s.checkDeadNodesAndFailBackups(); err != nil {
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
}
@@ -111,58 +106,52 @@ func (s *BackupsScheduler) IsSchedulerRunning() bool {
return s.lastBackupTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
}
func (s *BackupsScheduler) failBackupsInProgress() error {
backupsInProgress, err := s.backupRepository.FindByStatus(backups_core.BackupStatusInProgress)
func (s *BackupsScheduler) IsBackupNodesAvailable() bool {
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
if err != nil {
return err
s.logger.Error("Failed to get available nodes for health check", "error", err)
return false
}
for _, backup := range backupsInProgress {
if err := s.taskCancelManager.CancelTask(backup.ID); err != nil {
s.logger.Error(
"Failed to cancel backup via task cancel manager",
"backupId",
backup.ID,
"error",
err,
)
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
continue
}
failMessage := "Backup failed due to application restart"
backup.FailMessage = &failMessage
backup.Status = backups_core.BackupStatusFailed
backup.BackupSizeMb = 0
s.backuperNode.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupFailed,
&failMessage,
)
if err := s.backupRepository.Save(backup); err != nil {
return err
}
}
return nil
return len(nodes) > 0
}
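// Hypothetical caller (not part of this diff): a healthcheck on the primary
// node could surface the new IsBackupNodesAvailable signal and report a
// degraded status when no backup nodes have registered themselves.
func reportBackupNodesStatus(s *BackupsScheduler) string {
	if !s.IsBackupNodesAvailable() {
		return "degraded: no backup nodes available"
	}
	return "ok"
}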
func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool) {
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotifier bool) {
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
return
}
if backupConfig.StorageID == nil {
s.logger.Error("Backup config storage ID is nil", "databaseId", databaseID)
s.logger.Error("Backup config storage ID is nil", "databaseId", database.ID)
return
}
// Check for existing in-progress backups
inProgressBackups, err := s.backupRepository.FindByDatabaseIdAndStatus(
database.ID,
backups_core.BackupStatusInProgress,
)
if err != nil {
s.logger.Error(
"Failed to check for in-progress backups",
"databaseId",
database.ID,
"error",
err,
)
return
}
if len(inProgressBackups) > 0 {
s.logger.Warn(
"Backup already in progress for database, skipping new backup",
"databaseId",
database.ID,
"existingBackupId",
inProgressBackups[0].ID,
)
return
}
@@ -178,12 +167,22 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
return
}
backupID := uuid.New()
timestamp := time.Now().UTC()
backup := &backups_core.Backup{
ID: backupID,
FileName: fmt.Sprintf(
"%s-%s-%s",
files_utils.SanitizeFilename(database.Name),
timestamp.Format("20060102-150405"),
backupID.String(),
),
DatabaseID: backupConfig.DatabaseID,
StorageID: *backupConfig.StorageID,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 0,
CreatedAt: time.Now().UTC(),
CreatedAt: timestamp,
}
if err := s.backupRepository.Save(backup); err != nil {
@@ -237,8 +236,8 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
s.backupToNodeRelations[*leastBusyNodeID] = relation
} else {
s.backupToNodeRelations[*leastBusyNodeID] = BackupToNodeRelation{
NodeID: *leastBusyNodeID,
BackupsIDs: []uuid.UUID{backup.ID},
*leastBusyNodeID,
[]uuid.UUID{backup.ID},
}
}
@@ -266,6 +265,10 @@ func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Ba
return 0
}
if lastBackup.IsSkipRetry {
return 0
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
@@ -298,74 +301,6 @@ func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Ba
return maxFailedTriesCount - len(lastFailedBackups)
}
func (s *BackupsScheduler) cleanOldBackups() error {
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
backupStorePeriod := backupConfig.StorePeriod
if backupStorePeriod == period.PeriodForever {
continue
}
storeDuration := backupStorePeriod.ToDuration()
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
backupConfig.DatabaseID,
dateBeforeBackupsShouldBeDeleted,
)
if err != nil {
s.logger.Error(
"Failed to find old backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
for _, backup := range oldBackups {
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
s.logger.Error(
"Failed to get storage by ID",
"storageId",
backup.StorageID,
"error",
err,
)
continue
}
encryptor := encryption.GetFieldEncryptor()
err = storage.DeleteFile(encryptor, backup.ID)
if err != nil {
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
}
if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
continue
}
s.logger.Info(
"Deleted old backup",
"backupId",
backup.ID,
"databaseId",
backupConfig.DatabaseID,
)
}
}
return nil
}
func (s *BackupsScheduler) runPendingBackups() error {
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
@@ -406,7 +341,13 @@ func (s *BackupsScheduler) runPendingBackups() error {
backupConfig.BackupInterval.Interval,
)
s.StartBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
database, err := s.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
if err != nil {
s.logger.Error("Failed to get database by ID", "error", err)
continue
}
s.StartBackup(database, remainedBackupTryCount == 1)
continue
}
}
@@ -414,6 +355,49 @@ func (s *BackupsScheduler) runPendingBackups() error {
return nil
}
func (s *BackupsScheduler) failBackupsInProgress() error {
backupsInProgress, err := s.backupRepository.FindByStatus(backups_core.BackupStatusInProgress)
if err != nil {
return err
}
for _, backup := range backupsInProgress {
if err := s.taskCancelManager.CancelTask(backup.ID); err != nil {
s.logger.Error(
"Failed to cancel backup via task cancel manager",
"backupId",
backup.ID,
"error",
err,
)
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
continue
}
failMessage := "Backup failed due to application restart"
backup.FailMessage = &failMessage
backup.Status = backups_core.BackupStatusFailed
backup.BackupSizeMb = 0
s.backuperNode.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupFailed,
&failMessage,
)
if err := s.backupRepository.Save(backup); err != nil {
return err
}
}
return nil
}
func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
if err != nil {

View File

@@ -492,7 +492,7 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
assert.NoError(t, err)
// Scheduler assigns backup to mock node
GetBackupsScheduler().StartBackup(database.ID, false)
GetBackupsScheduler().StartBackup(database, false)
time.Sleep(100 * time.Millisecond)
backups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -595,7 +595,7 @@ func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {
assert.NoError(t, err)
// Start a backup and assign it to the node
GetBackupsScheduler().StartBackup(database.ID, false)
GetBackupsScheduler().StartBackup(database, false)
time.Sleep(100 * time.Millisecond)
backups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -892,7 +892,7 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
t.Logf("Initial active tasks: %d", initialActiveTasks)
// Start backup
scheduler.StartBackup(database.ID, false)
scheduler.StartBackup(database, false)
// Wait for backup to complete
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
@@ -995,7 +995,7 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
t.Logf("Initial active tasks: %d", initialActiveTasks)
// Start backup
scheduler.StartBackup(database.ID, false)
scheduler.StartBackup(database, false)
// Wait for backup to fail
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
@@ -1033,3 +1033,289 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_WhenBackupAlreadyInProgress_SkipsNewBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create an in-progress backup manually
inProgressBackup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 0,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(inProgressBackup)
assert.NoError(t, err)
// Try to start a new backup - should be skipped
GetBackupsScheduler().StartBackup(database, false)
time.Sleep(200 * time.Millisecond)
// Verify only 1 backup exists (the original in-progress one)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
assert.Equal(t, inProgressBackup.ID, backups[0].ID)
time.Sleep(200 * time.Millisecond)
}
func Test_RunPendingBackups_WhenLastBackupFailedWithIsSkipRetry_SkipsBackupEvenWithRetriesEnabled(
t *testing.T,
) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups with retries enabled and high retry count
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = true
backupConfig.MaxFailedTriesCount = 5
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create a failed backup with IsSkipRetry set to true
failMessage := "backup failed due to size limit exceeded"
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusFailed,
FailMessage: &failMessage,
IsSkipRetry: true,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
// Verify GetRemainedBackupTryCount returns 0 even though retries are enabled
lastBackup, err := backupRepository.FindLastByDatabaseID(database.ID)
assert.NoError(t, err)
assert.NotNil(t, lastBackup)
remainedTries := GetBackupsScheduler().GetRemainedBackupTryCount(lastBackup)
assert.Equal(t, 0, remainedTries, "Should return 0 tries when IsSkipRetry is true")
// Run the scheduler
GetBackupsScheduler().runPendingBackups()
time.Sleep(100 * time.Millisecond)
// Verify no new backup was created (still only 1 backup exists)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1, "No retry should be attempted when IsSkipRetry is true")
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCalled(t *testing.T) {
cache_utils.ClearAllCache()
// Create mock tracking use case
mockUseCase := NewMockTrackingBackupUsecase()
// Create BackuperNode with mock use case
backuperNode := CreateTestBackuperNodeWithUseCase(mockUseCase)
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// Create scheduler
scheduler := CreateTestScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
// Setup test data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
// Create 2 separate databases
database1 := databases.CreateTestDatabase(workspace.ID, storage, notifier)
database2 := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// Cleanup backups for database1
backups1, _ := backupRepository.FindByDatabaseID(database1.ID)
for _, backup := range backups1 {
backupRepository.DeleteByID(backup.ID)
}
// Cleanup backups for database2
backups2, _ := backupRepository.FindByDatabaseID(database2.ID)
for _, backup := range backups2 {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database1)
databases.RemoveTestDatabase(database2)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for database1
backupConfig1, err := backups_config.GetBackupConfigService().
GetBackupConfigByDbId(database1.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig1.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig1.IsBackupsEnabled = true
backupConfig1.StorePeriod = period.PeriodWeek
backupConfig1.Storage = storage
backupConfig1.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig1)
assert.NoError(t, err)
// Enable backups for database2
backupConfig2, err := backups_config.GetBackupConfigService().
GetBackupConfigByDbId(database2.ID)
assert.NoError(t, err)
backupConfig2.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig2.IsBackupsEnabled = true
backupConfig2.StorePeriod = period.PeriodWeek
backupConfig2.Storage = storage
backupConfig2.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig2)
assert.NoError(t, err)
// Start 2 backups simultaneously
t.Log("Starting backup for database1")
scheduler.StartBackup(database1, false)
t.Log("Starting backup for database2")
scheduler.StartBackup(database2, false)
// Wait up to 10 seconds for both backups to complete
t.Log("Waiting for both backups to complete...")
success := assert.Eventually(t, func() bool {
callCount := mockUseCase.GetCallCount()
t.Logf("Current call count: %d/2", callCount)
return callCount == 2
}, 10*time.Second, 200*time.Millisecond, "Both use cases should be called within 10 seconds")
if !success {
t.Logf("Test failed: Only %d out of 2 use cases were called", mockUseCase.GetCallCount())
}
// Verify both backup IDs were received
calledBackupIDs := mockUseCase.GetCalledBackupIDs()
t.Logf("Called backup IDs: %v", calledBackupIDs)
assert.Len(t, calledBackupIDs, 2, "Both backup IDs should be tracked")
// Verify both backups exist in repository and are completed
backups1, err := backupRepository.FindByDatabaseID(database1.ID)
assert.NoError(t, err)
assert.Len(t, backups1, 1, "Database1 should have 1 backup")
if len(backups1) > 0 {
t.Logf("Database1 backup status: %s", backups1[0].Status)
}
backups2, err := backupRepository.FindByDatabaseID(database2.ID)
assert.NoError(t, err)
assert.Len(t, backups2, 1, "Database2 should have 1 backup")
if len(backups2) > 0 {
t.Logf("Database2 backup status: %s", backups2[0].Status)
}
// Verify both backups completed successfully
if len(backups1) > 0 {
assert.Equal(t, backups_core.BackupStatusCompleted, backups1[0].Status,
"Database1 backup should be completed")
}
if len(backups2) > 0 {
assert.Equal(t, backups_core.BackupStatusCompleted, backups2[0].Status,
"Database2 backup should be completed")
}
time.Sleep(200 * time.Millisecond)
}

View File

@@ -55,19 +55,38 @@ func CreateTestBackuperNode() *BackuperNode {
}
}
func CreateTestBackuperNodeWithUseCase(useCase backups_core.CreateBackupUsecase) *BackuperNode {
return &BackuperNode{
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: useCase,
nodeID: uuid.New(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
func CreateTestScheduler() *BackupsScheduler {
return &BackupsScheduler{
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
taskCancelManager,
backupNodesRegistry,
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]BackupToNodeRelation),
CreateTestBackuperNode(),
sync.Once{},
atomic.Bool{},
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
taskCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
lastBackupTime: time.Now().UTC(),
logger: logger.GetLogger(),
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
backuperNode: CreateTestBackuperNode(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}

View File

@@ -6,6 +6,7 @@ import (
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/databases"
users_middleware "databasus-backend/internal/features/users/middleware"
files_utils "databasus-backend/internal/util/files"
"fmt"
"io"
"net/http"
@@ -322,7 +323,7 @@ func (c *BackupController) generateBackupFilename(
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")
// Sanitize database name for filename (replace spaces and special chars)
safeName := sanitizeFilename(database.Name)
safeName := files_utils.SanitizeFilename(database.Name)
// Determine extension based on database type
extension := c.getBackupExtension(database.Type)
@@ -346,33 +347,6 @@ func (c *BackupController) getBackupExtension(
}
}
func sanitizeFilename(name string) string {
// Replace characters that are invalid in filenames
replacer := map[rune]rune{
' ': '_',
'/': '-',
'\\': '-',
':': '-',
'*': '-',
'?': '-',
'"': '-',
'<': '-',
'>': '-',
'|': '-',
}
result := make([]rune, 0, len(name))
for _, char := range name {
if replacement, exists := replacer[char]; exists {
result = append(result, replacement)
} else {
result = append(result, char)
}
}
return string(result)
}
func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uuid.UUID) {
ticker := time.NewTicker(backups_download.GetDownloadHeartbeatInterval())
defer ticker.Stop()

View File

@@ -32,6 +32,7 @@ import (
workspaces_models "databasus-backend/internal/features/workspaces/models"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/util/encryption"
files_utils "databasus-backend/internal/util/files"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
)
@@ -80,7 +81,7 @@ func Test_GetBackups_PermissionsEnforced(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, _ := createTestDatabaseWithBackups(workspace, owner, router)
database, _, storage := createTestDatabaseWithBackups(workspace, owner, router)
var testUserToken string
if tt.isGlobalAdmin {
@@ -122,6 +123,12 @@ func Test_GetBackups_PermissionsEnforced(t *testing.T) {
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -218,6 +225,10 @@ func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -261,6 +272,10 @@ func Test_CreateBackup_AuditLogWritten(t *testing.T) {
}
}
assert.True(t, found, "Audit log for backup creation not found")
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
@@ -314,7 +329,7 @@ func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
var testUserToken string
if tt.isGlobalAdmin {
@@ -358,6 +373,12 @@ func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, 0, len(response.Backups))
}
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -367,7 +388,7 @@ func Test_DeleteBackup_AuditLogWritten(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
test_utils.MakeDeleteRequest(
t,
@@ -398,6 +419,12 @@ func Test_DeleteBackup_AuditLogWritten(t *testing.T) {
}
}
assert.True(t, found, "Audit log for backup deletion not found")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
@@ -444,7 +471,7 @@ func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
var testUserToken string
if tt.isGlobalAdmin {
@@ -488,6 +515,12 @@ func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -497,7 +530,7 @@ func Test_DownloadBackup_WithValidToken_Success(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse backups_download.GenerateDownloadTokenResponse
@@ -524,6 +557,12 @@ func Test_DownloadBackup_WithValidToken_Success(t *testing.T) {
contentDisposition := testResp.Headers.Get("Content-Disposition")
assert.Contains(t, contentDisposition, "attachment")
assert.Contains(t, contentDisposition, tokenResponse.Filename)
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DownloadBackup_WithoutToken_Unauthorized(t *testing.T) {
@@ -531,7 +570,7 @@ func Test_DownloadBackup_WithoutToken_Unauthorized(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
// Try to download without token
testResp := test_utils.MakeGetRequest(
@@ -543,6 +582,12 @@ func Test_DownloadBackup_WithoutToken_Unauthorized(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "download token is required")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DownloadBackup_WithInvalidToken_Unauthorized(t *testing.T) {
@@ -550,7 +595,7 @@ func Test_DownloadBackup_WithInvalidToken_Unauthorized(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
// Try to download with invalid token
testResp := test_utils.MakeGetRequest(
@@ -562,6 +607,12 @@ func Test_DownloadBackup_WithInvalidToken_Unauthorized(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DownloadBackup_WithExpiredToken_Unauthorized(t *testing.T) {
@@ -569,7 +620,7 @@ func Test_DownloadBackup_WithExpiredToken_Unauthorized(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
// Get user for token generation
userService := users_services.GetUserService()
@@ -611,6 +662,12 @@ func Test_DownloadBackup_WithExpiredToken_Unauthorized(t *testing.T) {
}
}
assert.False(t, found, "Audit log should NOT be created for failed download with expired token")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DownloadBackup_TokenUsedOnce_CannotReuseToken(t *testing.T) {
@@ -618,7 +675,7 @@ func Test_DownloadBackup_TokenUsedOnce_CannotReuseToken(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse backups_download.GenerateDownloadTokenResponse
@@ -651,6 +708,12 @@ func Test_DownloadBackup_TokenUsedOnce_CannotReuseToken(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DownloadBackup_WithDifferentBackupToken_Unauthorized(t *testing.T) {
@@ -705,6 +768,13 @@ func Test_DownloadBackup_WithDifferentBackupToken_Unauthorized(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
// Cleanup
databases.RemoveTestDatabase(database1)
databases.RemoveTestDatabase(database2)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
@@ -712,7 +782,7 @@ func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse backups_download.GenerateDownloadTokenResponse
@@ -756,6 +826,12 @@ func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
}
}
assert.True(t, found, "Audit log for backup download not found")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DownloadBackup_ProperFilenameForPostgreSQL(t *testing.T) {
@@ -856,6 +932,12 @@ func Test_DownloadBackup_ProperFilenameForPostgreSQL(t *testing.T) {
contentDisposition,
"Filename should contain timestamp",
)
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -875,7 +957,7 @@ func Test_SanitizeFilename(t *testing.T) {
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := sanitizeFilename(tt.input)
result := files_utils.SanitizeFilename(tt.input)
assert.Equal(t, tt.expected, result)
})
}
@@ -948,6 +1030,12 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
}
}
assert.True(t, foundCancelLog, "Cancel audit log should be created")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_ConcurrentDownloadPrevention(t *testing.T) {
@@ -955,7 +1043,7 @@ func Test_ConcurrentDownloadPrevention(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
var token1Response backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
@@ -1003,6 +1091,12 @@ func Test_ConcurrentDownloadPrevention(t *testing.T) {
if !service.IsDownloadInProgress(owner.UserID) {
t.Log("Warning: First download completed before we could test concurrency")
<-downloadComplete
// Cleanup before early return
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
return
}
@@ -1049,6 +1143,12 @@ func Test_ConcurrentDownloadPrevention(t *testing.T) {
t.Log(
"Successfully prevented concurrent downloads and allowed subsequent downloads after completion",
)
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
@@ -1056,7 +1156,7 @@ func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
var token1Response backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
@@ -1092,6 +1192,12 @@ func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
if !service.IsDownloadInProgress(owner.UserID) {
t.Log("Warning: First download completed before we could test token generation blocking")
<-downloadComplete
// Cleanup before early return
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
return
}
@@ -1131,6 +1237,12 @@ func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
t.Log(
"Successfully blocked token generation during download and allowed generation after completion",
)
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func createTestRouter() *gin.Engine {
@@ -1222,7 +1334,7 @@ func createTestDatabaseWithBackups(
workspace *workspaces_models.Workspace,
owner *users_dto.SignInResponseDTO,
router *gin.Engine,
) (*databases.Database, *backups_core.Backup) {
) (*databases.Database, *backups_core.Backup, *storages.Storage) {
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
@@ -1242,7 +1354,7 @@ func createTestDatabaseWithBackups(
backup := createTestBackup(database, owner)
return database, backup
return database, backup, storage
}
func createTestBackup(
@@ -1255,11 +1367,24 @@ func createTestBackup(
panic(err)
}
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
if err != nil || len(storages) == 0 {
loadedStorages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
if err != nil || len(loadedStorages) == 0 {
panic("No storage found for workspace")
}
// Filter out system storages
var nonSystemStorages []*storages.Storage
for _, storage := range loadedStorages {
if !storage.IsSystem {
nonSystemStorages = append(nonSystemStorages, storage)
}
}
if len(nonSystemStorages) == 0 {
panic("No non-system storage found for workspace")
}
storages := nonSystemStorages
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
@@ -1283,7 +1408,7 @@ func createTestBackup(
context.Background(),
encryption.GetFieldEncryptor(),
logger,
backup.ID,
backup.ID.String(),
reader,
); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
@@ -1320,7 +1445,7 @@ func Test_BandwidthThrottling_SingleDownload_Uses75Percent(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
bandwidthManager := backups_download.GetBandwidthManager()
initialCount := bandwidthManager.GetActiveDownloadCount()
@@ -1370,6 +1495,12 @@ func Test_BandwidthThrottling_SingleDownload_Uses75Percent(t *testing.T) {
time.Sleep(50 * time.Millisecond)
finalCount := bandwidthManager.GetActiveDownloadCount()
assert.Equal(t, initialCount, finalCount, "Download should be unregistered after completion")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_BandwidthThrottling_MultipleDownloads_ShareBandwidth(t *testing.T) {
@@ -1489,6 +1620,12 @@ func Test_BandwidthThrottling_MultipleDownloads_ShareBandwidth(t *testing.T) {
time.Sleep(100 * time.Millisecond)
finalCount := bandwidthManager.GetActiveDownloadCount()
assert.Equal(t, initialCount, finalCount, "All downloads should be unregistered")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_BandwidthThrottling_DynamicAdjustment(t *testing.T) {
@@ -1577,4 +1714,10 @@ func Test_BandwidthThrottling_DynamicAdjustment(t *testing.T) {
time.Sleep(100 * time.Millisecond)
finalCount := bandwidthManager.GetActiveDownloadCount()
assert.Equal(t, initialCount, finalCount, "All downloads completed and unregistered")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
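The teardown block repeated throughout these download and bandwidth tests always runs in the same order. For reference, a small helper capturing that order might look like the sketch below; the helper name is hypothetical and does not exist in this change, only the calls it wraps (and the imports of the surrounding test file) do.

// cleanupBackupTest mirrors the inline teardown used above: remove the database
// first so its backups and backup config cascade away, give the cascade a moment
// to finish, then remove the storage and finally the workspace.
func cleanupBackupTest(
	router *gin.Engine,
	workspace *workspaces_models.Workspace,
	database *databases.Database,
	storageID uuid.UUID,
) {
	databases.RemoveTestDatabase(database)
	time.Sleep(50 * time.Millisecond) // wait for cascading deletes
	storages.RemoveTestStorage(storageID)
	workspaces_testing.RemoveTestWorkspace(workspace, router)
}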

View File

@@ -8,8 +8,6 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
"github.com/google/uuid"
)
type NotificationSender interface {
@@ -23,7 +21,7 @@ type NotificationSender interface {
type CreateBackupUsecase interface {
Execute(
ctx context.Context,
backupID uuid.UUID,
backup *Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,

View File

@@ -8,13 +8,15 @@ import (
)
type Backup struct {
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
FileName string `json:"fileName" gorm:"column:file_name;type:text;not null"`
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
StorageID uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;not null"`
Status BackupStatus `json:"status" gorm:"column:status;not null"`
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
IsSkipRetry bool `json:"isSkipRetry" gorm:"column:is_skip_retry;type:boolean;not null"`
BackupSizeMb float64 `json:"backupSizeMb" gorm:"column:backup_size_mb;default:0"`

View File

@@ -212,3 +212,36 @@ func (r *BackupRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error
return count, nil
}
func (r *BackupRepository) GetTotalSizeByDatabase(databaseID uuid.UUID) (float64, error) {
var totalSize float64
if err := storage.
GetDb().
Model(&Backup{}).
Select("COALESCE(SUM(backup_size_mb), 0)").
Where("database_id = ? AND status != ?", databaseID, BackupStatusInProgress).
Scan(&totalSize).Error; err != nil {
return 0, err
}
return totalSize, nil
}
func (r *BackupRepository) FindOldestByDatabaseExcludingInProgress(
databaseID uuid.UUID,
limit int,
) ([]*Backup, error) {
var backups []*Backup
if err := storage.
GetDb().
Where("database_id = ? AND status != ?", databaseID, BackupStatusInProgress).
Order("created_at ASC").
Limit(limit).
Find(&backups).Error; err != nil {
return nil, err
}
return backups, nil
}
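For orientation, the sketch below shows one way a cleaner could combine these two new queries to enforce a per-database total-size cap by pruning the oldest finished backups first. The function and its wiring are assumptions made for illustration; only BackupRepository, BackupCleaner.DeleteBackup and the "0 means unlimited" convention appear elsewhere in this change.

// enforceTotalSizeLimit is a hypothetical illustration, not part of this diff.
func enforceTotalSizeLimit(
	cleaner *backuping.BackupCleaner,
	repo *backups_core.BackupRepository,
	databaseID uuid.UUID,
	maxTotalSizeMb float64,
) error {
	if maxTotalSizeMb <= 0 {
		return nil // 0 is treated as "unlimited"
	}
	for {
		total, err := repo.GetTotalSizeByDatabase(databaseID)
		if err != nil {
			return err
		}
		if total <= maxTotalSizeMb {
			return nil
		}
		oldest, err := repo.FindOldestByDatabaseExcludingInProgress(databaseID, 1)
		if err != nil {
			return err
		}
		if len(oldest) == 0 {
			return nil // nothing left that can be pruned
		}
		if err := cleaner.DeleteBackup(oldest[0]); err != nil {
			return err
		}
	}
}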

View File

@@ -25,22 +25,23 @@ var backupRepository = &backups_core.BackupRepository{}
var taskCancelManager = task_cancellation.GetTaskCancelManager()
var backupService = &BackupService{
databaseService: databases.GetDatabaseService(),
storageService: storages.GetStorageService(),
backupRepository: backupRepository,
notifierService: notifiers.GetNotifierService(),
notificationSender: notifiers.GetNotifierService(),
backupConfigService: backups_config.GetBackupConfigService(),
secretKeyService: encryption_secrets.GetSecretKeyService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
logger: logger.GetLogger(),
backupRemoveListeners: []backups_core.BackupRemoveListener{},
workspaceService: workspaces_services.GetWorkspaceService(),
auditLogService: audit_logs.GetAuditLogService(),
taskCancelManager: taskCancelManager,
downloadTokenService: backups_download.GetDownloadTokenService(),
backupSchedulerService: backuping.GetBackupsScheduler(),
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
notifiers.GetNotifierService(),
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
usecases.GetCreateBackupUsecase(),
logger.GetLogger(),
[]backups_core.BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
taskCancelManager,
backups_download.GetDownloadTokenService(),
backuping.GetBackupsScheduler(),
backuping.GetBackupCleaner(),
}
var backupController = &BackupController{

View File

@@ -21,6 +21,7 @@ import (
users_models "databasus-backend/internal/features/users/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
files_utils "databasus-backend/internal/util/files"
"github.com/google/uuid"
)
@@ -46,6 +47,7 @@ type BackupService struct {
taskCancelManager *task_cancellation.TaskCancelManager
downloadTokenService *backups_download.DownloadTokenService
backupSchedulerService *backuping.BackupsScheduler
backupCleaner *backuping.BackupCleaner
}
func (s *BackupService) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
@@ -91,7 +93,7 @@ func (s *BackupService) MakeBackupWithAuth(
return errors.New("insufficient permissions to create backup for this database")
}
s.backupSchedulerService.StartBackup(databaseID, true)
s.backupSchedulerService.StartBackup(database, true)
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Backup manually initiated for database: %s", database.Name),
@@ -180,16 +182,12 @@ func (s *BackupService) DeleteBackup(
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup deleted for database: %s (ID: %s)",
database.Name,
backupID.String(),
),
fmt.Sprintf("Backup deleted for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
return s.deleteBackup(backup)
return s.backupCleaner.DeleteBackup(backup)
}
func (s *BackupService) GetBackup(backupID uuid.UUID) (*backups_core.Backup, error) {
@@ -231,11 +229,7 @@ func (s *BackupService) CancelBackup(
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup cancelled for database: %s (ID: %s)",
database.Name,
backupID.String(),
),
fmt.Sprintf("Backup cancelled for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
@@ -275,11 +269,7 @@ func (s *BackupService) GetBackupFile(
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup file downloaded for database: %s (ID: %s)",
database.Name,
backupID.String(),
),
fmt.Sprintf("Backup file downloaded for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
@@ -292,29 +282,6 @@ func (s *BackupService) GetBackupFile(
return reader, backup, database, nil
}
func (s *BackupService) deleteBackup(backup *backups_core.Backup) error {
for _, listener := range s.backupRemoveListeners {
if err := listener.OnBeforeBackupRemove(backup); err != nil {
return err
}
}
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
return err
}
err = storage.DeleteFile(s.fieldEncryptor, backup.ID)
if err != nil {
// we do not return an error here because cleanup is sometimes performed
// before an unavailable storage is removed or changed - therefore we should
// proceed even if an error occurs
s.logger.Error("Failed to delete backup file", "error", err)
}
return s.backupRepository.DeleteByID(backup.ID)
}
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
databaseID,
@@ -336,7 +303,7 @@ func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
}
for _, dbBackup := range dbBackups {
err := s.deleteBackup(dbBackup)
err := s.backupCleaner.DeleteBackup(dbBackup)
if err != nil {
return err
}
@@ -358,7 +325,7 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
return nil, fmt.Errorf("failed to get storage: %w", err)
}
fileReader, err := storage.GetFile(s.fieldEncryptor, backup.ID)
fileReader, err := storage.GetFile(s.fieldEncryptor, backup.ID.String())
if err != nil {
return nil, fmt.Errorf("failed to get backup file: %w", err)
}
@@ -512,11 +479,7 @@ func (s *BackupService) WriteAuditLogForDownload(
database *databases.Database,
) {
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup file downloaded for database: %s (ID: %s)",
database.Name,
backup.ID.String(),
),
fmt.Sprintf("Backup file downloaded for database: %s", database.Name),
&userID,
database.WorkspaceID,
)
@@ -543,7 +506,7 @@ func (s *BackupService) generateBackupFilename(
database *databases.Database,
) string {
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")
safeName := sanitizeFilename(database.Name)
safeName := files_utils.SanitizeFilename(database.Name)
extension := s.getBackupExtension(database.Type)
return fmt.Sprintf("%s_backup_%s%s", safeName, timestamp, extension)
}
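As a worked example of the layout above: only the "%s_backup_%s%s" format and the timestamp layout come from the code; the sanitized name "Prod_Orders" and the ".sql" extension are assumptions for illustration.

package main

import (
	"fmt"
	"time"
)

func main() {
	createdAt := time.Date(2026, 2, 13, 16, 56, 56, 0, time.UTC)
	timestamp := createdAt.Format("2006-01-02_15-04-05") // "2026-02-13_16-56-56"
	// Assumed inputs: SanitizeFilename("Prod Orders") -> "Prod_Orders", extension ".sql".
	fmt.Printf("%s_backup_%s%s\n", "Prod_Orders", timestamp, ".sql")
	// Output: Prod_Orders_backup_2026-02-13_16-56-56.sql
}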

View File

@@ -5,6 +5,7 @@ import (
"errors"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
usecases_mariadb "databasus-backend/internal/features/backups/backups/usecases/mariadb"
usecases_mongodb "databasus-backend/internal/features/backups/backups/usecases/mongodb"
usecases_mysql "databasus-backend/internal/features/backups/backups/usecases/mysql"
@@ -12,8 +13,6 @@ import (
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/storages"
"github.com/google/uuid"
)
type CreateBackupUsecase struct {
@@ -25,7 +24,7 @@ type CreateBackupUsecase struct {
func (uc *CreateBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
@@ -35,7 +34,7 @@ func (uc *CreateBackupUsecase) Execute(
case databases.DatabaseTypePostgres:
return uc.CreatePostgresqlBackupUsecase.Execute(
ctx,
backupID,
backup,
backupConfig,
database,
storage,
@@ -45,7 +44,7 @@ func (uc *CreateBackupUsecase) Execute(
case databases.DatabaseTypeMysql:
return uc.CreateMysqlBackupUsecase.Execute(
ctx,
backupID,
backup,
backupConfig,
database,
storage,
@@ -55,7 +54,7 @@ func (uc *CreateBackupUsecase) Execute(
case databases.DatabaseTypeMariadb:
return uc.CreateMariadbBackupUsecase.Execute(
ctx,
backupID,
backup,
backupConfig,
database,
storage,
@@ -65,7 +64,7 @@ func (uc *CreateBackupUsecase) Execute(
case databases.DatabaseTypeMongodb:
return uc.CreateMongodbBackupUsecase.Execute(
ctx,
backupID,
backup,
backupConfig,
database,
storage,

View File

@@ -19,6 +19,7 @@ import (
"databasus-backend/internal/config"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -52,7 +53,7 @@ type writeResult struct {
func (uc *CreateMariadbBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
db *databases.Database,
storage *storages.Storage,
@@ -82,7 +83,7 @@ func (uc *CreateMariadbBackupUsecase) Execute(
return uc.streamToStorage(
ctx,
backupID,
backup,
backupConfig,
tools.GetMariadbExecutable(
tools.MariadbExecutableMariadbDump,
@@ -108,13 +109,15 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
"--single-transaction",
"--routines",
"--quick",
"--skip-extended-insert",
"--verbose",
}
if mdb.HasPrivilege("TRIGGER") {
args = append(args, "--triggers")
}
if mdb.HasPrivilege("EVENT") {
if mdb.HasPrivilege("EVENT") && !mdb.IsExcludeEvents {
args = append(args, "--events")
}
@@ -134,7 +137,7 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
func (uc *CreateMariadbBackupUsecase) streamToStorage(
parentCtx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
mariadbBin string,
args []string,
@@ -185,7 +188,7 @@ func (uc *CreateMariadbBackupUsecase) streamToStorage(
storageReader, storageWriter := io.Pipe()
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
backupID,
backup.ID,
backupConfig,
storageWriter,
)
@@ -202,7 +205,13 @@ func (uc *CreateMariadbBackupUsecase) streamToStorage(
saveErrCh := make(chan error, 1)
go func() {
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
saveErr := storage.SaveFile(
ctx,
uc.fieldEncryptor,
uc.logger,
backup.FileName,
storageReader,
)
saveErrCh <- saveErr
}()
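All four dump usecases in this change share the same streaming shape: the dump command's stdout flows through a pipe (optionally wrapped by an encrypting writer) into storage.SaveFile running in its own goroutine, so the backup is never buffered in memory. Below is a minimal, self-contained sketch of that shape with placeholder names; streamDumpToStorage and the fileStore interface are not the project's API.

package backupstream

import (
	"context"
	"io"
	"os/exec"
)

// fileStore is a placeholder for the storage used above; only the idea of
// "save a named file from a reader" is taken from the diff.
type fileStore interface {
	SaveFile(ctx context.Context, name string, r io.Reader) error
}

func streamDumpToStorage(ctx context.Context, store fileStore, fileName, bin string, args ...string) error {
	pr, pw := io.Pipe()

	cmd := exec.CommandContext(ctx, bin, args...)
	cmd.Stdout = pw // dump output flows straight into the pipe

	saveErrCh := make(chan error, 1)
	go func() {
		// Storage consumes the read end while the dump is still running.
		err := store.SaveFile(ctx, fileName, pr)
		pr.CloseWithError(err) // unblock the writer if storage fails early
		saveErrCh <- err
	}()

	runErr := cmd.Run()
	// Closing the write end lets SaveFile observe EOF (or the dump's error).
	pw.CloseWithError(runErr)

	if saveErr := <-saveErrCh; saveErr != nil {
		return saveErr
	}
	return runErr
}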

View File

@@ -16,6 +16,7 @@ import (
"databasus-backend/internal/config"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -46,7 +47,7 @@ type writeResult struct {
func (uc *CreateMongodbBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
db *databases.Database,
storage *storages.Storage,
@@ -76,7 +77,7 @@ func (uc *CreateMongodbBackupUsecase) Execute(
return uc.streamToStorage(
ctx,
backupID,
backup,
backupConfig,
tools.GetMongodbExecutable(
tools.MongodbExecutableMongodump,
@@ -114,7 +115,7 @@ func (uc *CreateMongodbBackupUsecase) buildMongodumpArgs(
func (uc *CreateMongodbBackupUsecase) streamToStorage(
parentCtx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
mongodumpBin string,
args []string,
@@ -163,7 +164,7 @@ func (uc *CreateMongodbBackupUsecase) streamToStorage(
storageReader, storageWriter := io.Pipe()
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
backupID,
backup.ID,
backupConfig,
storageWriter,
)
@@ -175,7 +176,13 @@ func (uc *CreateMongodbBackupUsecase) streamToStorage(
saveErrCh := make(chan error, 1)
go func() {
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
saveErr := storage.SaveFile(
ctx,
uc.fieldEncryptor,
uc.logger,
backup.FileName,
storageReader,
)
saveErrCh <- saveErr
}()

View File

@@ -19,6 +19,7 @@ import (
"databasus-backend/internal/config"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -52,7 +53,7 @@ type writeResult struct {
func (uc *CreateMysqlBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
db *databases.Database,
storage *storages.Storage,
@@ -82,7 +83,7 @@ func (uc *CreateMysqlBackupUsecase) Execute(
return uc.streamToStorage(
ctx,
backupID,
backup,
backupConfig,
tools.GetMysqlExecutable(
my.Version,
@@ -107,6 +108,7 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
"--routines",
"--set-gtid-purged=OFF",
"--quick",
"--skip-extended-insert",
"--verbose",
}
@@ -148,7 +150,7 @@ func (uc *CreateMysqlBackupUsecase) getNetworkCompressionArgs(version tools.Mysq
func (uc *CreateMysqlBackupUsecase) streamToStorage(
parentCtx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
mysqlBin string,
args []string,
@@ -199,7 +201,7 @@ func (uc *CreateMysqlBackupUsecase) streamToStorage(
storageReader, storageWriter := io.Pipe()
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
backupID,
backup.ID,
backupConfig,
storageWriter,
)
@@ -216,7 +218,13 @@ func (uc *CreateMysqlBackupUsecase) streamToStorage(
saveErrCh := make(chan error, 1)
go func() {
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
saveErr := storage.SaveFile(
ctx,
uc.fieldEncryptor,
uc.logger,
backup.FileName,
storageReader,
)
saveErrCh <- saveErr
}()

View File

@@ -16,6 +16,7 @@ import (
"databasus-backend/internal/config"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -53,7 +54,7 @@ type writeResult struct {
func (uc *CreatePostgresqlBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
db *databases.Database,
storage *storages.Storage,
@@ -88,7 +89,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
return uc.streamToStorage(
ctx,
backupID,
backup,
backupConfig,
tools.GetPostgresqlExecutable(
pg.Version,
@@ -107,7 +108,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
// streamToStorage streams pg_dump output directly to storage
func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
parentCtx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
pgBin string,
args []string,
@@ -166,7 +167,7 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
storageReader, storageWriter := io.Pipe()
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
backupID,
backup.ID,
backupConfig,
storageWriter,
)
@@ -181,7 +182,13 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
// Start streaming into storage in its own goroutine
saveErrCh := make(chan error, 1)
go func() {
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
saveErr := storage.SaveFile(
ctx,
uc.fieldEncryptor,
uc.logger,
backup.FileName,
storageReader,
)
saveErrCh <- saveErr
}()

View File

@@ -16,6 +16,7 @@ type BackupConfigController struct {
func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
router.POST("/backup-configs/save", c.SaveBackupConfig)
router.GET("/backup-configs/database/:id/plan", c.GetDatabasePlan)
router.GET("/backup-configs/database/:id", c.GetBackupConfigByDbID)
router.GET("/backup-configs/storage/:id/is-using", c.IsStorageUsing)
router.GET("/backup-configs/storage/:id/databases-count", c.CountDatabasesForStorage)
@@ -92,6 +93,39 @@ func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
ctx.JSON(http.StatusOK, backupConfig)
}
// GetDatabasePlan
// @Summary Get database plan by database ID
// @Description Get the plan limits for a specific database (max backup size, max total size, max storage period)
// @Tags backup-configs
// @Produce json
// @Param id path string true "Database ID"
// @Success 200 {object} plans.DatabasePlan
// @Failure 400 {object} map[string]string "Invalid database ID"
// @Failure 401 {object} map[string]string "User not authenticated"
// @Failure 404 {object} map[string]string "Database not found or access denied"
// @Router /backup-configs/database/{id}/plan [get]
func (c *BackupConfigController) GetDatabasePlan(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
return
}
plan, err := c.backupConfigService.GetDatabasePlan(user, id)
if err != nil {
ctx.JSON(http.StatusNotFound, gin.H{"error": "database plan not found"})
return
}
ctx.JSON(http.StatusOK, plan)
}
// IsStorageUsing
// @Summary Check if storage is being used
// @Description Check if a storage is currently being used by any backup configuration

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"strconv"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -16,11 +17,14 @@ import (
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
local_storage "databasus-backend/internal/features/storages/models/local"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/storage"
"databasus-backend/internal/util/period"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
@@ -89,6 +93,11 @@ func Test_SaveBackupConfig_PermissionsEnforced(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
@@ -152,6 +161,11 @@ func Test_SaveBackupConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *test
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
timeOfDay := "04:00"
@@ -242,6 +256,11 @@ func Test_GetBackupConfigByDbID_PermissionsEnforced(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
@@ -290,6 +309,11 @@ func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
var response BackupConfig
test_utils.MakeGetRequestAndUnmarshal(
t,
@@ -300,14 +324,214 @@ func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T)
&response,
)
var plan plans.DatabasePlan
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
"Bearer "+owner.Token,
http.StatusOK,
&plan,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.False(t, response.IsBackupsEnabled)
assert.Equal(t, period.PeriodWeek, response.StorePeriod)
assert.Equal(t, plan.MaxStoragePeriod, response.StorePeriod)
assert.Equal(t, plan.MaxBackupSizeMB, response.MaxBackupSizeMB)
assert.Equal(t, plan.MaxBackupsTotalSizeMB, response.MaxBackupsTotalSizeMB)
assert.True(t, response.IsRetryIfFailed)
assert.Equal(t, 3, response.MaxFailedTriesCount)
assert.NotNil(t, response.BackupInterval)
}
func Test_GetDatabasePlan_ForNewDatabase_PlanAlwaysReturned(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
var response plans.DatabasePlan
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
"Bearer "+owner.Token,
http.StatusOK,
&response,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.NotNil(t, response.MaxBackupSizeMB)
assert.NotNil(t, response.MaxBackupsTotalSizeMB)
assert.NotEmpty(t, response.MaxStoragePeriod)
}
func Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Get plan via API (triggers auto-creation)
var plan plans.DatabasePlan
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
"Bearer "+owner.Token,
http.StatusOK,
&plan,
)
assert.Equal(t, database.ID, plan.DatabaseID)
// Adjust plan limits directly in the database to fixed, restrictive values
err := storage.GetDb().Model(&plans.DatabasePlan{}).
Where("database_id = ?", database.ID).
Updates(map[string]any{
"max_backup_size_mb": 100,
"max_backups_total_size_mb": 1000,
"max_storage_period": period.PeriodMonth,
}).Error
assert.NoError(t, err)
// Test 1: Try to save a backup config whose max backup size exceeds the plan limit
timeOfDay := "04:00"
backupConfigExceededSize := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 200, // Exceeds limit of 100
MaxBackupsTotalSizeMB: 800,
}
respExceededSize := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigExceededSize,
http.StatusBadRequest,
)
assert.Contains(t, string(respExceededSize.Body), "max backup size exceeds plan limit")
// Test 2: Try to save a backup config whose total backups size exceeds the plan limit
backupConfigExceededTotal := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 50,
MaxBackupsTotalSizeMB: 2000, // Exceeds limit of 1000
}
respExceededTotal := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigExceededTotal,
http.StatusBadRequest,
)
assert.Contains(t, string(respExceededTotal.Body), "max total backups size exceeds plan limit")
// Test 3: Try to save a backup config whose storage period exceeds the plan limit
backupConfigExceededPeriod := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodYear, // Exceeds limit of Month
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 80,
MaxBackupsTotalSizeMB: 800,
}
respExceededPeriod := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigExceededPeriod,
http.StatusBadRequest,
)
assert.Contains(t, string(respExceededPeriod.Body), "storage period exceeds plan limit")
// Test 4: Save backup config within all limits - should succeed
backupConfigValid := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek, // Within Month limit
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 80, // Within 100 limit
MaxBackupsTotalSizeMB: 800, // Within 1000 limit
}
var responseValid BackupConfig
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigValid,
http.StatusOK,
&responseValid,
)
assert.Equal(t, database.ID, responseValid.DatabaseID)
assert.Equal(t, int64(80), responseValid.MaxBackupSizeMB)
assert.Equal(t, int64(800), responseValid.MaxBackupsTotalSizeMB)
assert.Equal(t, period.PeriodWeek, responseValid.StorePeriod)
}
func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
@@ -340,6 +564,10 @@ func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
)
storage := createTestStorage(workspace.ID)
defer func() {
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
var testUserToken string
if tt.isStorageOwner {
testUserToken = storageOwner.Token
@@ -372,10 +600,6 @@ func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "error")
}
// Cleanup
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -387,6 +611,11 @@ func Test_SaveBackupConfig_WithEncryptionNone_ConfigSaved(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
@@ -426,6 +655,11 @@ func Test_SaveBackupConfig_WithEncryptionEncrypted_ConfigSaved(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
@@ -536,6 +770,15 @@ func Test_TransferDatabase_PermissionsEnforced(t *testing.T) {
targetStorage := createTestStorage(targetWorkspace.ID)
defer func() {
// Cleanup in correct order to avoid foreign key violations
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond) // Wait for cascade delete of backup_config
storages.RemoveTestStorage(targetStorage.ID)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
@@ -628,6 +871,12 @@ func Test_TransferDatabase_NonMemberInSourceWorkspace_CannotTransfer(t *testing.
router,
)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
request := TransferDatabaseRequest{
TargetWorkspaceID: targetWorkspace.ID,
}
@@ -668,6 +917,12 @@ func Test_TransferDatabase_NonMemberInTargetWorkspace_CannotTransfer(t *testing.
router,
)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
request := TransferDatabaseRequest{
TargetWorkspaceID: targetWorkspace.ID,
}
@@ -695,6 +950,13 @@ func Test_TransferDatabase_ToNewStorage_DatabaseTransferd(t *testing.T) {
sourceStorage := createTestStorage(sourceWorkspace.ID)
targetStorage := createTestStorage(targetWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(200 * time.Millisecond) // Wait for cascading deletes
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database.ID,
@@ -774,6 +1036,13 @@ func Test_TransferDatabase_WithExistingStorage_DatabaseAndStorageTransferd(t *te
database := createTestDatabaseViaAPI("Test Database", sourceWorkspace.ID, owner.Token, router)
storage := createTestStorage(sourceWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(200 * time.Millisecond) // Wait for cascading deletes
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database.ID,
@@ -863,6 +1132,14 @@ func Test_TransferDatabase_StorageHasOtherDBs_CannotTransfer(t *testing.T) {
)
storage := createTestStorage(sourceWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database1)
databases.RemoveTestDatabase(database2)
time.Sleep(200 * time.Millisecond) // Wait for cascading deletes
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
timeOfDay := "04:00"
backupConfigRequest1 := BackupConfig{
DatabaseID: database1.ID,
@@ -945,6 +1222,14 @@ func Test_TransferDatabase_WithNotifiers_NotifiersTransferred(t *testing.T) {
targetStorage := createTestStorage(targetWorkspace.ID)
notifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(200 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
database.Notifiers = []notifiers.Notifier{*notifier}
var updatedDatabase databases.Database
test_utils.MakePostRequestAndUnmarshal(
@@ -1048,6 +1333,15 @@ func Test_TransferDatabase_NotifierHasOtherDBs_NotifierSkipped(t *testing.T) {
targetStorage := createTestStorage(targetWorkspace.ID)
sharedNotifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database1)
databases.RemoveTestDatabase(database2)
time.Sleep(200 * time.Millisecond)
notifiers.RemoveTestNotifier(sharedNotifier)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
database1.Notifiers = []notifiers.Notifier{*sharedNotifier}
test_utils.MakePostRequest(
t,
@@ -1160,6 +1454,16 @@ func Test_TransferDatabase_WithMultipleNotifiers_OnlyExclusiveOnesTransferred(t
exclusiveNotifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)
sharedNotifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database1)
databases.RemoveTestDatabase(database2)
time.Sleep(200 * time.Millisecond)
notifiers.RemoveTestNotifier(exclusiveNotifier)
notifiers.RemoveTestNotifier(sharedNotifier)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
database1.Notifiers = []notifiers.Notifier{*exclusiveNotifier, *sharedNotifier}
test_utils.MakePostRequest(
t,
@@ -1271,6 +1575,14 @@ func Test_TransferDatabase_WithTargetNotifiers_NotifiersAssigned(t *testing.T) {
targetStorage := createTestStorage(targetWorkspace.ID)
targetNotifier := notifiers.CreateTestNotifier(targetWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(200 * time.Millisecond)
notifiers.RemoveTestNotifier(targetNotifier)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database.ID,
@@ -1342,6 +1654,15 @@ func Test_TransferDatabase_TargetNotifierFromDifferentWorkspace_ReturnsBadReques
targetStorage := createTestStorage(targetWorkspace.ID)
wrongNotifier := notifiers.CreateTestNotifier(otherWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(200 * time.Millisecond)
notifiers.RemoveTestNotifier(wrongNotifier)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
workspaces_testing.RemoveTestWorkspace(otherWorkspace, router)
}()
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database.ID,
@@ -1399,6 +1720,14 @@ func Test_TransferDatabase_TargetStorageFromDifferentWorkspace_ReturnsBadRequest
sourceStorage := createTestStorage(sourceWorkspace.ID)
wrongStorage := createTestStorage(otherWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(200 * time.Millisecond)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
workspaces_testing.RemoveTestWorkspace(otherWorkspace, router)
}()
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database.ID,
@@ -1443,6 +1772,115 @@ func Test_TransferDatabase_TargetStorageFromDifferentWorkspace_ReturnsBadRequest
assert.Contains(t, string(testResp.Body), "target storage does not belong to target workspace")
}
func Test_SaveBackupConfig_WithSystemStorage_CanBeUsedByAnyDatabase(t *testing.T) {
router := createTestRouterWithStorageForTransfer()
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
workspaceA := workspaces_testing.CreateTestWorkspace("Workspace A", owner1, router)
workspaceB := workspaces_testing.CreateTestWorkspace("Workspace B", owner2, router)
databaseA := createTestDatabaseViaAPI("Database A", workspaceA.ID, owner1.Token, router)
// Test 1: Regular storage from workspace B cannot be used by a database in workspace A
regularStorageB := createTestStorage(workspaceB.ID)
timeOfDay := "04:00"
backupConfigWithRegularStorage := BackupConfig{
DatabaseID: databaseA.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
StorageID: &regularStorageB.ID,
Storage: regularStorageB,
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
}
respRegular := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner1.Token,
backupConfigWithRegularStorage,
http.StatusBadRequest,
)
assert.Contains(t, string(respRegular.Body), "storage does not belong to the same workspace")
// Test 2: System storage from workspace B CAN be used by a database in workspace A
systemStorageB := &storages.Storage{
WorkspaceID: workspaceB.ID,
Type: storages.StorageTypeLocal,
Name: "Test System Storage " + uuid.New().String(),
IsSystem: true,
LocalStorage: &local_storage.LocalStorage{},
}
var savedSystemStorage storages.Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+admin.Token,
*systemStorageB,
http.StatusOK,
&savedSystemStorage,
)
assert.True(t, savedSystemStorage.IsSystem)
backupConfigWithSystemStorage := BackupConfig{
DatabaseID: databaseA.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
StorageID: &savedSystemStorage.ID,
Storage: &savedSystemStorage,
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
}
var savedConfig BackupConfig
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner1.Token,
backupConfigWithSystemStorage,
http.StatusOK,
&savedConfig,
)
assert.Equal(t, databaseA.ID, savedConfig.DatabaseID)
assert.NotNil(t, savedConfig.StorageID)
assert.Equal(t, savedSystemStorage.ID, *savedConfig.StorageID)
assert.True(t, savedConfig.IsBackupsEnabled)
// Cleanup: database first (cascades to backup_config), then storages, then workspaces
databases.RemoveTestDatabase(databaseA)
storages.RemoveTestStorage(regularStorageB.ID)
storages.RemoveTestStorage(savedSystemStorage.ID)
workspaces_testing.RemoveTestWorkspace(workspaceA, router)
workspaces_testing.RemoveTestWorkspace(workspaceB, router)
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,

View File

@@ -6,6 +6,7 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/logger"
@@ -18,6 +19,7 @@ var backupConfigService = &BackupConfigService{
storages.GetStorageService(),
notifiers.GetNotifierService(),
workspaces_services.GetWorkspaceService(),
plans.GetDatabasePlanService(),
nil,
}
var backupConfigController = &BackupConfigController{

View File

@@ -1,7 +1,9 @@
package backups_config
import (
"databasus-backend/internal/config"
"databasus-backend/internal/features/intervals"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/util/period"
"errors"
@@ -31,6 +33,11 @@ type BackupConfig struct {
MaxFailedTriesCount int `json:"maxFailedTriesCount" gorm:"column:max_failed_tries_count;type:int;not null"`
Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
// MaxBackupSizeMB limits individual backup size. 0 = unlimited.
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
// MaxBackupsTotalSizeMB limits total size of all backups. 0 = unlimited.
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
}
func (h *BackupConfig) TableName() string {
@@ -70,7 +77,7 @@ func (b *BackupConfig) AfterFind(tx *gorm.DB) error {
return nil
}
func (b *BackupConfig) Validate() error {
func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
// Backup interval is required either as ID or as object
if b.BackupIntervalID == uuid.Nil && b.BackupInterval == nil {
return errors.New("backup interval is required")
@@ -89,20 +96,59 @@ func (b *BackupConfig) Validate() error {
return errors.New("encryption must be NONE or ENCRYPTED")
}
if config.GetEnv().IsCloud {
if b.Encryption != BackupEncryptionEncrypted {
return errors.New("encryption is mandatory for cloud storage")
}
}
if b.MaxBackupSizeMB < 0 {
return errors.New("max backup size must be non-negative")
}
if b.MaxBackupsTotalSizeMB < 0 {
return errors.New("max backups total size must be non-negative")
}
// Validate against plan limits
// Check storage period limit
if plan.MaxStoragePeriod != period.PeriodForever {
if b.StorePeriod.CompareTo(plan.MaxStoragePeriod) > 0 {
return errors.New("storage period exceeds plan limit")
}
}
// Check max backup size limit (0 in plan means unlimited)
if plan.MaxBackupSizeMB > 0 {
if b.MaxBackupSizeMB == 0 || b.MaxBackupSizeMB > plan.MaxBackupSizeMB {
return errors.New("max backup size exceeds plan limit")
}
}
// Check max total backups size limit (0 in plan means unlimited)
if plan.MaxBackupsTotalSizeMB > 0 {
if b.MaxBackupsTotalSizeMB == 0 ||
b.MaxBackupsTotalSizeMB > plan.MaxBackupsTotalSizeMB {
return errors.New("max total backups size exceeds plan limit")
}
}
return nil
}
func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
return &BackupConfig{
DatabaseID: newDatabaseID,
IsBackupsEnabled: b.IsBackupsEnabled,
StorePeriod: b.StorePeriod,
BackupIntervalID: uuid.Nil,
BackupInterval: b.BackupInterval.Copy(),
StorageID: b.StorageID,
SendNotificationsOn: b.SendNotificationsOn,
IsRetryIfFailed: b.IsRetryIfFailed,
MaxFailedTriesCount: b.MaxFailedTriesCount,
Encryption: b.Encryption,
DatabaseID: newDatabaseID,
IsBackupsEnabled: b.IsBackupsEnabled,
StorePeriod: b.StorePeriod,
BackupIntervalID: uuid.Nil,
BackupInterval: b.BackupInterval.Copy(),
StorageID: b.StorageID,
SendNotificationsOn: b.SendNotificationsOn,
IsRetryIfFailed: b.IsRetryIfFailed,
MaxFailedTriesCount: b.MaxFailedTriesCount,
Encryption: b.Encryption,
MaxBackupSizeMB: b.MaxBackupSizeMB,
MaxBackupsTotalSizeMB: b.MaxBackupsTotalSizeMB,
}
}

View File

@@ -0,0 +1,391 @@
package backups_config
import (
"testing"
"databasus-backend/internal/features/intervals"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/util/period"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_Validate_WhenStoragePeriodIsWeekAndPlanAllowsMonth_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodWeek
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenStoragePeriodIsYearAndPlanAllowsMonth_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodYear
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
err := config.Validate(plan)
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenStoragePeriodIsForeverAndPlanAllowsForever_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodForever
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodForever
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenStoragePeriodIsForeverAndPlanAllowsYear_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodForever
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodYear
err := config.Validate(plan)
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenStoragePeriodEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodMonth
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenBackupSize100MBAndPlanAllows500MB_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 100
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 500
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenBackupSize500MBAndPlanAllows100MB_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 500
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 100
err := config.Validate(plan)
assert.EqualError(t, err, "max backup size exceeds plan limit")
}
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 0
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanHas500MBLimit_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 500
err := config.Validate(plan)
assert.EqualError(t, err, "max backup size exceeds plan limit")
}
func Test_Validate_WhenBackupSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 500
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 500
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenTotalSize1GBAndPlanAllows5GB_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 1000
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 5000
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenTotalSize5GBAndPlanAllows1GB_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 5000
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 1000
err := config.Validate(plan)
assert.EqualError(t, err, "max total backups size exceeds plan limit")
}
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 0
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanHas1GBLimit_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 1000
err := config.Validate(plan)
assert.EqualError(t, err, "max total backups size exceeds plan limit")
}
func Test_Validate_WhenTotalSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 5000
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 5000
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenAllLimitsAreUnlimitedInPlan_AnyConfigurationPasses(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodForever
config.MaxBackupSizeMB = 0
config.MaxBackupsTotalSizeMB = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenMultipleLimitsExceeded_ValidationFailsWithFirstError(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodYear
config.MaxBackupSizeMB = 500
config.MaxBackupsTotalSizeMB = 5000
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
plan.MaxBackupSizeMB = 100
plan.MaxBackupsTotalSizeMB = 1000
err := config.Validate(plan)
assert.Error(t, err)
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenConfigHasInvalidIntervalButPlanIsValid_ValidationFailsOnInterval(
t *testing.T,
) {
config := createValidBackupConfig()
config.BackupIntervalID = uuid.Nil
config.BackupInterval = nil
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "backup interval is required")
}
func Test_Validate_WhenIntervalIsMissing_ValidationFailsRegardlessOfPlan(t *testing.T) {
config := createValidBackupConfig()
config.BackupIntervalID = uuid.Nil
config.BackupInterval = nil
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "backup interval is required")
}
func Test_Validate_WhenRetryEnabledButMaxTriesIsZero_ValidationFailsRegardlessOfPlan(t *testing.T) {
config := createValidBackupConfig()
config.IsRetryIfFailed = true
config.MaxFailedTriesCount = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "max failed tries count must be greater than 0")
}
func Test_Validate_WhenEncryptionIsInvalid_ValidationFailsRegardlessOfPlan(t *testing.T) {
config := createValidBackupConfig()
config.Encryption = "INVALID"
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "encryption must be NONE or ENCRYPTED")
}
func Test_Validate_WhenStoragePeriodIsEmpty_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = ""
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "store period is required")
}
func Test_Validate_WhenMaxBackupSizeIsNegative_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = -100
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "max backup size must be non-negative")
}
func Test_Validate_WhenMaxTotalSizeIsNegative_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = -1000
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "max backups total size must be non-negative")
}
func Test_Validate_WhenPlanLimitsAreAtBoundary_ValidationWorks(t *testing.T) {
tests := []struct {
name string
configPeriod period.Period
planPeriod period.Period
configSize int64
planSize int64
configTotal int64
planTotal int64
shouldSucceed bool
}{
{
name: "all values just under limit",
configPeriod: period.PeriodWeek,
planPeriod: period.PeriodMonth,
configSize: 99,
planSize: 100,
configTotal: 999,
planTotal: 1000,
shouldSucceed: true,
},
{
name: "all values equal to limit",
configPeriod: period.PeriodMonth,
planPeriod: period.PeriodMonth,
configSize: 100,
planSize: 100,
configTotal: 1000,
planTotal: 1000,
shouldSucceed: true,
},
{
name: "period just over limit",
configPeriod: period.Period3Month,
planPeriod: period.PeriodMonth,
configSize: 100,
planSize: 100,
configTotal: 1000,
planTotal: 1000,
shouldSucceed: false,
},
{
name: "size just over limit",
configPeriod: period.PeriodMonth,
planPeriod: period.PeriodMonth,
configSize: 101,
planSize: 100,
configTotal: 1000,
planTotal: 1000,
shouldSucceed: false,
},
{
name: "total size just over limit",
configPeriod: period.PeriodMonth,
planPeriod: period.PeriodMonth,
configSize: 100,
planSize: 100,
configTotal: 1001,
planTotal: 1000,
shouldSucceed: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = tt.configPeriod
config.MaxBackupSizeMB = tt.configSize
config.MaxBackupsTotalSizeMB = tt.configTotal
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = tt.planPeriod
plan.MaxBackupSizeMB = tt.planSize
plan.MaxBackupsTotalSizeMB = tt.planTotal
err := config.Validate(plan)
if tt.shouldSucceed {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
})
}
}
func createValidBackupConfig() *BackupConfig {
intervalID := uuid.New()
return &BackupConfig{
DatabaseID: uuid.New(),
IsBackupsEnabled: true,
StorePeriod: period.PeriodMonth,
BackupIntervalID: intervalID,
BackupInterval: &intervals.Interval{ID: intervalID},
SendNotificationsOn: []BackupNotificationType{},
IsRetryIfFailed: false,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 100,
MaxBackupsTotalSizeMB: 1000,
}
}
func createUnlimitedPlan() *plans.DatabasePlan {
return &plans.DatabasePlan{
DatabaseID: uuid.New(),
MaxBackupSizeMB: 0,
MaxBackupsTotalSizeMB: 0,
MaxStoragePeriod: period.PeriodForever,
}
}
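// A minimal sketch of the limit rule these tests appear to assume, where 0 means
// "unlimited" on both the config and the plan side (hypothetical helper, not the
// production Validate implementation):
func exceedsPlanLimitSketch(configMB, planMB int64) bool {
if planMB == 0 {
return false // unlimited plan: any config value fits
}
if configMB == 0 {
return true // an unlimited config value cannot fit under a finite plan limit
}
return configMB > planMB
}
// With that reading, a config of 0 against a plan limit of 1000 fails, 5000 against
// 5000 passes, and the boundary table above (999/1000 ok, 1001/1000 error) holds.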

View File

@@ -26,6 +26,12 @@ func Test_AttachNotifierFromSameWorkspace_SuccessfullyAttached(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
notifier := notifiers.CreateTestNotifier(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
database.Notifiers = []notifiers.Notifier{*notifier}
var response databases.Database
@@ -55,6 +61,13 @@ func Test_AttachNotifierFromDifferentWorkspace_ReturnsForbidden(t *testing.T) {
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
notifier := notifiers.CreateTestNotifier(workspace2.ID)
defer func() {
databases.RemoveTestDatabase(database)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace1, router)
workspaces_testing.RemoveTestWorkspace(workspace2, router)
}()
database.Notifiers = []notifiers.Notifier{*notifier}
testResp := test_utils.MakePostRequest(
@@ -77,6 +90,12 @@ func Test_DeleteNotifierWithAttachedDatabases_CannotDelete(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
notifier := notifiers.CreateTestNotifier(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
database.Notifiers = []notifiers.Notifier{*notifier}
var response databases.Database
@@ -114,6 +133,13 @@ func Test_TransferNotifierWithAttachedDatabase_CannotTransfer(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
notifier := notifiers.CreateTestNotifier(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
database.Notifiers = []notifiers.Notifier{*notifier}
var response databases.Database

View File

@@ -6,10 +6,10 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
users_models "databasus-backend/internal/features/users/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/period"
"github.com/google/uuid"
)
@@ -20,6 +20,7 @@ type BackupConfigService struct {
storageService *storages.StorageService
notifierService *notifiers.NotifierService
workspaceService *workspaces_services.WorkspaceService
databasePlanService *plans.DatabasePlanService
dbStorageChangeListener BackupConfigStorageChangeListener
}
@@ -45,7 +46,12 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
user *users_models.User,
backupConfig *BackupConfig,
) (*BackupConfig, error) {
if err := backupConfig.Validate(); err != nil {
plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
if err != nil {
return nil, err
}
if err := backupConfig.Validate(plan); err != nil {
return nil, err
}
@@ -71,7 +77,7 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
if err != nil {
return nil, err
}
if storage.WorkspaceID != *database.WorkspaceID {
if storage.WorkspaceID != *database.WorkspaceID && !storage.IsSystem {
return nil, errors.New("storage does not belong to the same workspace as the database")
}
}
@@ -82,7 +88,12 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
func (s *BackupConfigService) SaveBackupConfig(
backupConfig *BackupConfig,
) (*BackupConfig, error) {
if err := backupConfig.Validate(); err != nil {
plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
if err != nil {
return nil, err
}
if err := backupConfig.Validate(plan); err != nil {
return nil, err
}
@@ -120,6 +131,18 @@ func (s *BackupConfigService) GetBackupConfigByDbIdWithAuth(
return s.GetBackupConfigByDbId(databaseID)
}
func (s *BackupConfigService) GetDatabasePlan(
user *users_models.User,
databaseID uuid.UUID,
) (*plans.DatabasePlan, error) {
_, err := s.databaseService.GetDatabase(user, databaseID)
if err != nil {
return nil, err
}
return s.databasePlanService.GetDatabasePlan(databaseID)
}
func (s *BackupConfigService) GetBackupConfigByDbId(
databaseID uuid.UUID,
) (*BackupConfig, error) {
@@ -194,12 +217,19 @@ func (s *BackupConfigService) CreateDisabledBackupConfig(databaseID uuid.UUID) e
func (s *BackupConfigService) initializeDefaultConfig(
databaseID uuid.UUID,
) error {
plan, err := s.databasePlanService.GetDatabasePlan(databaseID)
if err != nil {
return err
}
timeOfDay := "04:00"
_, err := s.backupConfigRepository.Save(&BackupConfig{
DatabaseID: databaseID,
IsBackupsEnabled: false,
StorePeriod: period.PeriodWeek,
_, err = s.backupConfigRepository.Save(&BackupConfig{
DatabaseID: databaseID,
IsBackupsEnabled: false,
StorePeriod: plan.MaxStoragePeriod,
MaxBackupSizeMB: plan.MaxBackupSizeMB,
MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,

View File

@@ -27,6 +27,12 @@ func Test_AttachStorageFromSameWorkspace_SuccessfullyAttached(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
@@ -72,6 +78,13 @@ func Test_AttachStorageFromDifferentWorkspace_ReturnsForbidden(t *testing.T) {
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
storage := createTestStorage(workspace2.ID)
defer func() {
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace1, router)
workspaces_testing.RemoveTestWorkspace(workspace2, router)
}()
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
@@ -110,6 +123,12 @@ func Test_DeleteStorageWithAttachedDatabases_CannotDelete(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
@@ -163,6 +182,13 @@ func Test_TransferStorageWithAttachedDatabase_CannotTransfer(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,

View File

@@ -25,80 +25,6 @@ import (
"databasus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
GetDatabaseController(),
)
return router
}
func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
return &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
}
}
func getTestMariadbConfig() *mariadb.MariadbDatabase {
env := config.GetEnv()
portStr := env.TestMariadb1011Port
if portStr == "" {
portStr = "33111"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err))
}
testDbName := "testdb"
return &mariadb.MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
}
}
func getTestMongodbConfig() *mongodb.MongodbDatabase {
env := config.GetEnv()
portStr := env.TestMongodb70Port
if portStr == "" {
portStr = "27070"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err))
}
return &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "root",
Password: "rootpassword",
Database: "testdb",
AuthDatabase: "admin",
IsHttps: false,
CpuCount: 1,
}
}
func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
@@ -142,6 +68,7 @@ func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
var testUserToken string
if tt.isGlobalAdmin {
@@ -180,6 +107,7 @@ func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
)
if tt.expectSuccess {
defer RemoveTestDatabase(&response)
assert.Equal(t, "Test Database", response.Name)
assert.NotEqual(t, uuid.Nil, response.ID)
} else {
@@ -193,6 +121,7 @@ func Test_CreateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -258,8 +187,10 @@ func Test_UpdateDatabase_PermissionsEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var testUserToken string
if tt.isGlobalAdmin {
@@ -305,8 +236,10 @@ func Test_UpdateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
database.Name = "Hacked Name"
@@ -366,6 +299,7 @@ func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
@@ -396,6 +330,7 @@ func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) {
)
if !tt.expectSuccess {
defer RemoveTestDatabase(database)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
@@ -439,8 +374,10 @@ func Test_GetDatabase_PermissionsEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var testUser string
if tt.isGlobalAdmin {
@@ -517,9 +454,12 @@ func Test_GetDatabasesByWorkspace_PermissionsEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
db1 := createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(db1)
db2 := createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(db2)
var testUser string
if tt.isGlobalAdmin {
@@ -561,10 +501,14 @@ func Test_GetDatabasesByWorkspace_WhenMultipleDatabasesExist_ReturnsCorrectCount
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
createTestDatabaseViaAPI("Database 3", workspace.ID, owner.Token, router)
db1 := createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(db1)
db2 := createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(db2)
db3 := createTestDatabaseViaAPI("Database 3", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(db3)
var response []Database
test_utils.MakeGetRequestAndUnmarshal(
@@ -583,14 +527,19 @@ func Test_GetDatabasesByWorkspace_EnsuresCrossWorkspaceIsolation(t *testing.T) {
router := createTestRouter()
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace1 := workspaces_testing.CreateTestWorkspace("Workspace 1", owner1, router)
defer workspaces_testing.RemoveTestWorkspace(workspace1, router)
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
defer workspaces_testing.RemoveTestWorkspace(workspace2, router)
createTestDatabaseViaAPI("Workspace1 DB1", workspace1.ID, owner1.Token, router)
createTestDatabaseViaAPI("Workspace1 DB2", workspace1.ID, owner1.Token, router)
workspace1Db1 := createTestDatabaseViaAPI("Workspace1 DB1", workspace1.ID, owner1.Token, router)
defer RemoveTestDatabase(workspace1Db1)
workspace1Db2 := createTestDatabaseViaAPI("Workspace1 DB2", workspace1.ID, owner1.Token, router)
defer RemoveTestDatabase(workspace1Db2)
createTestDatabaseViaAPI("Workspace2 DB1", workspace2.ID, owner2.Token, router)
workspace2Db1 := createTestDatabaseViaAPI("Workspace2 DB1", workspace2.ID, owner2.Token, router)
defer RemoveTestDatabase(workspace2Db1)
var workspace1Dbs []Database
test_utils.MakeGetRequestAndUnmarshal(
@@ -667,8 +616,10 @@ func Test_CopyDatabase_PermissionsEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var testUserToken string
if tt.isGlobalAdmin {
@@ -700,6 +651,7 @@ func Test_CopyDatabase_PermissionsEnforced(t *testing.T) {
)
if tt.expectSuccess {
defer RemoveTestDatabase(&response)
assert.NotEqual(t, database.ID, response.ID)
assert.Contains(t, response.Name, "(Copy)")
} else {
@@ -713,8 +665,10 @@ func Test_CopyDatabase_CopyStaysInSameWorkspace(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var response Database
test_utils.MakePostRequestAndUnmarshal(
@@ -727,139 +681,14 @@ func Test_CopyDatabase_CopyStaysInSameWorkspace(t *testing.T) {
&response,
)
defer RemoveTestDatabase(&response)
assert.NotEqual(t, database.ID, response.ID)
assert.Equal(t, "Test Database (Copy)", response.Name)
assert.Equal(t, workspace.ID, *response.WorkspaceID)
assert.Equal(t, database.Type, response.Type)
}
func Test_TestConnection_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
isMember bool
isGlobalAdmin bool
expectAccessGranted bool
expectedStatusCodeOnErr int
}{
{
name: "workspace member can test connection",
isMember: true,
isGlobalAdmin: false,
expectAccessGranted: true,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
{
name: "non-member cannot test connection",
isMember: false,
isGlobalAdmin: false,
expectAccessGranted: false,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
{
name: "global admin can test connection",
isMember: false,
isGlobalAdmin: true,
expectAccessGranted: true,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var testUser string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUser = admin.Token
} else if tt.isMember {
testUser = owner.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUser = nonMember.Token
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/"+database.ID.String()+"/test-connection",
"Bearer "+testUser,
nil,
)
body := w.Body.String()
if tt.expectAccessGranted {
assert.True(
t,
w.Code == http.StatusOK ||
(w.Code == http.StatusBadRequest && strings.Contains(body, "connect")),
"Expected 200 OK or 400 with connection error, got %d: %s",
w.Code,
body,
)
} else {
assert.Equal(t, tt.expectedStatusCodeOnErr, w.Code)
assert.Contains(t, body, "insufficient permissions")
}
})
}
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,
token string,
router *gin.Engine,
) *Database {
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
request := Database{
Name: name,
WorkspaceID: &workspaceID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
},
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/create",
"Bearer "+token,
request,
)
if w.Code != http.StatusCreated {
panic(
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
)
}
var database Database
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
panic(err)
}
return &database
}
func Test_CreateDatabase_PasswordIsEncryptedInDB(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -1141,3 +970,207 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
})
}
}
func Test_TestConnection_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
isMember bool
isGlobalAdmin bool
expectAccessGranted bool
expectedStatusCodeOnErr int
}{
{
name: "workspace member can test connection",
isMember: true,
isGlobalAdmin: false,
expectAccessGranted: true,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
{
name: "non-member cannot test connection",
isMember: false,
isGlobalAdmin: false,
expectAccessGranted: false,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
{
name: "global admin can test connection",
isMember: false,
isGlobalAdmin: true,
expectAccessGranted: true,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var testUser string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUser = admin.Token
} else if tt.isMember {
testUser = owner.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUser = nonMember.Token
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/"+database.ID.String()+"/test-connection",
"Bearer "+testUser,
nil,
)
body := w.Body.String()
if tt.expectAccessGranted {
assert.True(
t,
w.Code == http.StatusOK ||
(w.Code == http.StatusBadRequest && strings.Contains(body, "connect")),
"Expected 200 OK or 400 with connection error, got %d: %s",
w.Code,
body,
)
} else {
assert.Equal(t, tt.expectedStatusCodeOnErr, w.Code)
assert.Contains(t, body, "insufficient permissions")
}
})
}
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,
token string,
router *gin.Engine,
) *Database {
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
request := Database{
Name: name,
WorkspaceID: &workspaceID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
},
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/create",
"Bearer "+token,
request,
)
if w.Code != http.StatusCreated {
panic(
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
)
}
var database Database
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
panic(err)
}
return &database
}
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
GetDatabaseController(),
)
return router
}
func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
return &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
}
}
func getTestMariadbConfig() *mariadb.MariadbDatabase {
env := config.GetEnv()
portStr := env.TestMariadb1011Port
if portStr == "" {
portStr = "33111"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err))
}
testDbName := "testdb"
return &mariadb.MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
}
}
func getTestMongodbConfig() *mongodb.MongodbDatabase {
env := config.GetEnv()
portStr := env.TestMongodb70Port
if portStr == "" {
portStr = "27070"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err))
}
return &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: config.GetEnv().TestLocalhost,
Port: &port,
Username: "root",
Password: "rootpassword",
Database: "testdb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
CpuCount: 1,
}
}

View File

@@ -25,13 +25,14 @@ type MariadbDatabase struct {
Version tools.MariadbVersion `json:"version" gorm:"type:text;not null"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
IsExcludeEvents bool `json:"isExcludeEvents" gorm:"type:boolean;default:false"`
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
}
func (m *MariadbDatabase) TableName() string {
@@ -124,6 +125,7 @@ func (m *MariadbDatabase) Update(incoming *MariadbDatabase) {
m.Username = incoming.Username
m.Database = incoming.Database
m.IsHttps = incoming.IsHttps
m.IsExcludeEvents = incoming.IsExcludeEvents
m.Privileges = incoming.Privileges
if incoming.Password != "" {
@@ -515,9 +517,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
hasProcess := false
hasAllPrivileges := false
// Escape underscores to match MariaDB's grant output format
// MariaDB escapes _ as \_ in SHOW GRANTS output
// Pattern matches either literal _ or escaped \_
escapedDbName := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
dbPatternStr := fmt.Sprintf(
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
regexp.QuoteMeta(database),
escapedDbName,
)
dbPattern := regexp.MustCompile(dbPatternStr)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
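// A minimal sketch of what the underscore escaping buys, assuming SHOW GRANTS
// output of the form "GRANT SELECT ON `test\_db`.* TO 'user'@'%'":
//
// db := "test_db"
// escaped := strings.ReplaceAll(regexp.QuoteMeta(db), "_", `(_|\\_)`) // -> test(_|\\_)db
// p := regexp.MustCompile(fmt.Sprintf(`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`, escaped))
// p.MatchString("GRANT SELECT ON `test_db`.* TO 'u'@'%'")   // true: literal underscore
// p.MatchString("GRANT SELECT ON `test\\_db`.* TO 'u'@'%'") // true: MariaDB-escaped underscore
//
// Matching only regexp.QuoteMeta(db) would miss the escaped form and report a
// fully privileged user as having no database-level grants.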

View File

@@ -694,6 +694,115 @@ func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
assert.NoError(t, err)
}
func Test_TestConnection_DatabaseWithUnderscoresAndAllPrivileges_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MariadbVersion
port string
}{
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMariadbContainer(t, tc.port, tc.version)
defer container.DB.Close()
underscoreDbName := "test_all_db"
_, err := container.DB.Exec(
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
)
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
)
}()
underscoreDSN := fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
container.Username,
container.Password,
container.Host,
container.Port,
underscoreDbName,
)
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
assert.NoError(t, err)
defer underscoreDB.Close()
_, err = underscoreDB.Exec(`
CREATE TABLE all_priv_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = underscoreDB.Exec(`INSERT INTO all_priv_test (data) VALUES ('test1')`)
assert.NoError(t, err)
allPrivUsername := fmt.Sprintf("allpriv%s", uuid.New().String()[:8])
allPrivPassword := "allprivpass123"
_, err = underscoreDB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
allPrivUsername,
allPrivPassword,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec(fmt.Sprintf(
"GRANT ALL PRIVILEGES ON `%s`.* TO '%s'@'%%'",
underscoreDbName,
allPrivUsername,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer dropUserSafe(underscoreDB, allPrivUsername)
mariadbModel := &MariadbDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: allPrivUsername,
Password: allPrivPassword,
Database: &underscoreDbName,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mariadbModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, mariadbModel.Privileges)
assert.Contains(t, mariadbModel.Privileges, "SELECT")
assert.Contains(t, mariadbModel.Privileges, "SHOW VIEW")
})
}
}
type MariadbContainer struct {
Host string
Port int

View File

@@ -26,12 +26,13 @@ type MongodbDatabase struct {
Version tools.MongodbVersion `json:"version" gorm:"type:text;not null"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Port *int `json:"port" gorm:"type:int"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database string `json:"database" gorm:"type:text;not null"`
AuthDatabase string `json:"authDatabase" gorm:"type:text;not null;default:'admin'"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
IsSrv bool `json:"isSrv" gorm:"column:is_srv;type:boolean;not null;default:false"`
CpuCount int `json:"cpuCount" gorm:"column:cpu_count;type:int;not null;default:1"`
}
@@ -43,9 +44,13 @@ func (m *MongodbDatabase) Validate() error {
if m.Host == "" {
return errors.New("host is required")
}
if m.Port == 0 {
return errors.New("port is required")
if !m.IsSrv {
if m.Port == nil || *m.Port == 0 {
return errors.New("port is required for standard connections")
}
}
if m.Username == "" {
return errors.New("username is required")
}
@@ -58,6 +63,7 @@ func (m *MongodbDatabase) Validate() error {
if m.CpuCount <= 0 {
return errors.New("cpu count must be greater than 0")
}
return nil
}
@@ -125,6 +131,7 @@ func (m *MongodbDatabase) Update(incoming *MongodbDatabase) {
m.Database = incoming.Database
m.AuthDatabase = incoming.AuthDatabase
m.IsHttps = incoming.IsHttps
m.IsSrv = incoming.IsSrv
m.CpuCount = incoming.CpuCount
if incoming.Password != "" {
@@ -455,12 +462,29 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string {
tlsParams = "&tls=true&tlsInsecure=true"
}
if m.IsSrv {
return fmt.Sprintf(
"mongodb+srv://%s:%s@%s/%s?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
m.Database,
authDB,
tlsParams,
)
}
port := 27017
if m.Port != nil {
port = *m.Port
}
return fmt.Sprintf(
"mongodb://%s:%s@%s:%d/%s?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
m.Port,
port,
m.Database,
authDB,
tlsParams,
@@ -479,12 +503,28 @@ func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
tlsParams = "&tls=true&tlsInsecure=true"
}
if m.IsSrv {
return fmt.Sprintf(
"mongodb+srv://%s:%s@%s/?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
authDB,
tlsParams,
)
}
port := 27017
if m.Port != nil {
port = *m.Port
}
return fmt.Sprintf(
"mongodb://%s:%s@%s:%d/?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
m.Port,
port,
authDB,
tlsParams,
)

View File

@@ -64,15 +64,17 @@ func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
defer dropUserSafe(container.Client, limitedUsername, container.AuthDatabase)
port := container.Port
mongodbModel := &MongodbDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Port: &port,
Username: limitedUsername,
Password: limitedPassword,
Database: container.Database,
AuthDatabase: container.AuthDatabase,
IsHttps: false,
IsSrv: false,
CpuCount: 1,
}
@@ -133,15 +135,17 @@ func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
defer dropUserSafe(container.Client, backupUsername, container.AuthDatabase)
port := container.Port
mongodbModel := &MongodbDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Port: &port,
Username: backupUsername,
Password: backupPassword,
Database: container.Database,
AuthDatabase: container.AuthDatabase,
IsHttps: false,
IsSrv: false,
CpuCount: 1,
}
@@ -442,15 +446,17 @@ func connectToMongodbContainer(
}
func createMongodbModel(container *MongodbContainer) *MongodbDatabase {
port := container.Port
return &MongodbDatabase{
Version: container.Version,
Host: container.Host,
Port: container.Port,
Port: &port,
Username: container.Username,
Password: container.Password,
Database: container.Database,
AuthDatabase: container.AuthDatabase,
IsHttps: false,
IsSrv: false,
CpuCount: 1,
}
}
@@ -489,3 +495,157 @@ func assertWriteDenied(t *testing.T, err error) {
strings.Contains(errStr, "permission denied"),
"Expected authorization error, got: %v", err)
}
func Test_BuildConnectionURI_WithSrvFormat_ReturnsCorrectUri(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "cluster0.example.mongodb.net",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: true,
}
uri := model.buildConnectionURI("testpass123")
assert.Contains(t, uri, "mongodb+srv://")
assert.Contains(t, uri, "testuser")
assert.Contains(t, uri, "testpass123")
assert.Contains(t, uri, "cluster0.example.mongodb.net")
assert.Contains(t, uri, "/mydb")
assert.Contains(t, uri, "authSource=admin")
assert.Contains(t, uri, "connectTimeoutMS=15000")
assert.NotContains(t, uri, ":27017")
}
func Test_BuildConnectionURI_WithStandardFormat_ReturnsCorrectUri(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "localhost",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
}
uri := model.buildConnectionURI("testpass123")
assert.Contains(t, uri, "mongodb://")
assert.Contains(t, uri, "testuser")
assert.Contains(t, uri, "testpass123")
assert.Contains(t, uri, "localhost:27017")
assert.Contains(t, uri, "/mydb")
assert.Contains(t, uri, "authSource=admin")
assert.Contains(t, uri, "connectTimeoutMS=15000")
assert.NotContains(t, uri, "mongodb+srv://")
}
func Test_BuildConnectionURI_WithNullPort_UsesDefault(t *testing.T) {
model := &MongodbDatabase{
Host: "localhost",
Port: nil,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
}
uri := model.buildConnectionURI("testpass123")
assert.Contains(t, uri, "localhost:27017")
}
func Test_BuildMongodumpURI_WithSrvFormat_ReturnsCorrectUri(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "cluster0.example.mongodb.net",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: true,
}
uri := model.BuildMongodumpURI("testpass123")
assert.Contains(t, uri, "mongodb+srv://")
assert.Contains(t, uri, "testuser")
assert.Contains(t, uri, "testpass123")
assert.Contains(t, uri, "cluster0.example.mongodb.net")
assert.Contains(t, uri, "/?authSource=admin")
assert.Contains(t, uri, "connectTimeoutMS=15000")
assert.NotContains(t, uri, ":27017")
assert.NotContains(t, uri, "/mydb")
}
func Test_BuildMongodumpURI_WithStandardFormat_ReturnsCorrectUri(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "localhost",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
}
uri := model.BuildMongodumpURI("testpass123")
assert.Contains(t, uri, "mongodb://")
assert.Contains(t, uri, "testuser")
assert.Contains(t, uri, "testpass123")
assert.Contains(t, uri, "localhost:27017")
assert.Contains(t, uri, "/?authSource=admin")
assert.Contains(t, uri, "connectTimeoutMS=15000")
assert.NotContains(t, uri, "mongodb+srv://")
assert.NotContains(t, uri, "/mydb")
}
func Test_Validate_SrvConnection_AllowsNullPort(t *testing.T) {
model := &MongodbDatabase{
Host: "cluster0.example.mongodb.net",
Port: nil,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: true,
CpuCount: 1,
}
err := model.Validate()
assert.NoError(t, err)
}
func Test_Validate_StandardConnection_RequiresPort(t *testing.T) {
model := &MongodbDatabase{
Host: "localhost",
Port: nil,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
CpuCount: 1,
}
err := model.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), "port is required for standard connections")
}

View File

@@ -489,9 +489,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
hasProcess := false
hasAllPrivileges := false
// Escape underscores to match MySQL's grant output format
// MySQL escapes _ as \_ in SHOW GRANTS output
// Pattern matches either literal _ or escaped \_
escapedDbName := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
dbPatternStr := fmt.Sprintf(
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
regexp.QuoteMeta(database),
escapedDbName,
)
dbPattern := regexp.MustCompile(dbPatternStr)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)

View File

@@ -674,6 +674,112 @@ func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
assert.NoError(t, err)
}
func Test_TestConnection_DatabaseWithUnderscoresAndAllPrivileges_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MysqlVersion
port string
}{
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMysqlContainer(t, tc.port, tc.version)
defer container.DB.Close()
underscoreDbName := "test_all_db"
_, err := container.DB.Exec(
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
)
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
)
}()
underscoreDSN := fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
container.Username,
container.Password,
container.Host,
container.Port,
underscoreDbName,
)
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
assert.NoError(t, err)
defer underscoreDB.Close()
_, err = underscoreDB.Exec(`
CREATE TABLE all_priv_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = underscoreDB.Exec(`INSERT INTO all_priv_test (data) VALUES ('test1')`)
assert.NoError(t, err)
allPrivUsername := fmt.Sprintf("allpriv_%s", uuid.New().String()[:8])
allPrivPassword := "allprivpass123"
_, err = underscoreDB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
allPrivUsername,
allPrivPassword,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec(fmt.Sprintf(
"GRANT ALL PRIVILEGES ON `%s`.* TO '%s'@'%%'",
underscoreDbName,
allPrivUsername,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", allPrivUsername),
)
}()
mysqlModel := &MysqlDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: allPrivUsername,
Password: allPrivPassword,
Database: &underscoreDbName,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mysqlModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, mysqlModel.Privileges)
assert.Contains(t, mysqlModel.Privileges, "SELECT")
assert.Contains(t, mysqlModel.Privileges, "SHOW VIEW")
})
}
}
type MysqlContainer struct {
Host string
Port int

View File

@@ -13,6 +13,7 @@ import (
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"gorm.io/gorm"
)
@@ -394,10 +395,13 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
//
// This method performs the following operations atomically in a single transaction:
// 1. Creates a PostgreSQL user with a UUID-based password
// 2. Grants CONNECT privilege on the database
// 3. Grants USAGE on all non-system schemas
// 4. Grants SELECT on all existing tables and sequences
// 5. Sets default privileges for future tables and sequences
// 2. Revokes CREATE privilege on public schema from PUBLIC role
// 3. Grants CONNECT privilege on the database
// 4. Discovers all user-created schemas
// 5. Grants USAGE on all non-system schemas
// 6. Grants SELECT on all existing tables and sequences
// 7. Sets default privileges for future tables and sequences
// 8. Verifies user creation before committing
//
// Security features:
// - Username format: "databasus-{8-char-uuid}" for uniqueness
@@ -487,33 +491,56 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
return "", "", fmt.Errorf("failed to create user: %w", err)
}
// Step 1.5: Revoke CREATE privilege from PUBLIC role on public schema
// Step 2: Check if public schema exists and revoke CREATE privilege if it does
// This is necessary because all PostgreSQL users inherit CREATE privilege on the
// public schema through the PUBLIC role. This is a one-time operation that affects
// the entire database, making it more secure by default.
// Note: This only affects the public schema; other schemas are unaffected.
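// From PostgreSQL 15 onward, PUBLIC no longer has CREATE on the public schema by
// default, so on newer servers this revoke is typically a no-op; it still matters
// for version 14 and older and for databases restored with legacy privileges.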
_, err = tx.Exec(ctx, `REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
if err != nil {
logger.Error("Failed to revoke CREATE on public from PUBLIC", "error", err)
if !strings.Contains(err.Error(), "schema \"public\" does not exist") &&
!strings.Contains(err.Error(), "permission denied") {
return "", "", fmt.Errorf("failed to revoke CREATE from PUBLIC: %w", err)
}
}
// Now revoke from the specific user as well (belt and suspenders)
_, err = tx.Exec(ctx, fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, baseUsername))
if err != nil {
logger.Error(
"Failed to revoke CREATE on public schema from user",
"error",
err,
"username",
baseUsername,
var publicSchemaExists bool
err = tx.QueryRow(ctx, `
SELECT EXISTS(
SELECT 1 FROM information_schema.schemata
WHERE schema_name = 'public'
)
`).Scan(&publicSchemaExists)
if err != nil {
return "", "", fmt.Errorf("failed to check if public schema exists: %w", err)
}
// Step 2: Grant database connection privilege and revoke TEMP
if publicSchemaExists {
// Revoke CREATE from PUBLIC role (affects all users)
_, err = tx.Exec(ctx, `REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
if err != nil {
if strings.Contains(err.Error(), "permission denied") {
logger.Warn(
"Failed to revoke CREATE on public from PUBLIC (permission denied)",
"error",
err,
)
} else {
return "", "", fmt.Errorf("failed to revoke CREATE from PUBLIC on existing public schema: %w", err)
}
}
// Now revoke from the specific user as well (belt and suspenders)
_, err = tx.Exec(
ctx,
fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, baseUsername),
)
if err != nil {
logger.Warn(
"Failed to revoke CREATE on public schema from user",
"error",
err,
"username",
baseUsername,
)
}
} else {
logger.Info("Public schema does not exist, skipping CREATE privilege revocation")
}
// Step 3: Grant database connection privilege and revoke TEMP
_, err = tx.Exec(
ctx,
fmt.Sprintf(`GRANT CONNECT ON DATABASE "%s" TO "%s"`, *p.Database, baseUsername),
@@ -537,12 +564,23 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
logger.Warn("Failed to revoke TEMP privilege", "error", err, "username", baseUsername)
}
// Step 3: Discover all user-created schemas
rows, err := tx.Query(ctx, `
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
`)
// Step 4: Discover schemas to grant privileges on
// If IncludeSchemas is specified, only use those schemas; otherwise use all non-system schemas
var rows pgx.Rows
if len(p.IncludeSchemas) > 0 {
rows, err = tx.Query(ctx, `
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
AND schema_name = ANY($1::text[])
`, p.IncludeSchemas)
} else {
rows, err = tx.Query(ctx, `
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
`)
}
if err != nil {
return "", "", fmt.Errorf("failed to get schemas: %w", err)
}
@@ -562,7 +600,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
return "", "", fmt.Errorf("error iterating schemas: %w", err)
}
// Step 4: Grant USAGE on each schema and explicitly prevent CREATE
// Step 5: Grant USAGE on each schema and explicitly prevent CREATE
for _, schema := range schemas {
// Revoke CREATE specifically (handles inheritance from PUBLIC role)
_, err = tx.Exec(
@@ -591,51 +629,198 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
}
}
// Step 5: Grant SELECT on ALL existing tables and sequences
grantSelectSQL := fmt.Sprintf(`
DO $$
DECLARE
schema_rec RECORD;
BEGIN
FOR schema_rec IN
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
LOOP
EXECUTE format('GRANT SELECT ON ALL TABLES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
EXECUTE format('GRANT SELECT ON ALL SEQUENCES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
END LOOP;
END $$;
`, baseUsername, baseUsername)
// Step 6: Grant SELECT on ALL existing tables and sequences
// Use the already-filtered schemas list from Step 4
for _, schema := range schemas {
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`GRANT SELECT ON ALL TABLES IN SCHEMA "%s" TO "%s"`,
schema,
baseUsername,
),
)
if err != nil {
return "", "", fmt.Errorf(
"failed to grant select on tables in schema %s: %w",
schema,
err,
)
}
_, err = tx.Exec(ctx, grantSelectSQL)
if err != nil {
return "", "", fmt.Errorf("failed to grant select on tables: %w", err)
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`GRANT SELECT ON ALL SEQUENCES IN SCHEMA "%s" TO "%s"`,
schema,
baseUsername,
),
)
if err != nil {
return "", "", fmt.Errorf(
"failed to grant select on sequences in schema %s: %w",
schema,
err,
)
}
}
// Step 6: Set default privileges for FUTURE tables and sequences
defaultPrivilegesSQL := fmt.Sprintf(`
DO $$
DECLARE
schema_rec RECORD;
BEGIN
FOR schema_rec IN
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
LOOP
EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON TABLES TO "%s"', schema_rec.schema_name);
EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON SEQUENCES TO "%s"', schema_rec.schema_name);
END LOOP;
END $$;
`, baseUsername, baseUsername)
// Step 7: Set default privileges for FUTURE tables and sequences
// First, set default privileges for objects created by the current user
// Use the already-filtered schemas list from Step 4
for _, schema := range schemas {
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`ALTER DEFAULT PRIVILEGES IN SCHEMA "%s" GRANT SELECT ON TABLES TO "%s"`,
schema,
baseUsername,
),
)
if err != nil {
return "", "", fmt.Errorf(
"failed to set default privileges for tables in schema %s: %w",
schema,
err,
)
}
_, err = tx.Exec(ctx, defaultPrivilegesSQL)
if err != nil {
return "", "", fmt.Errorf("failed to set default privileges: %w", err)
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`ALTER DEFAULT PRIVILEGES IN SCHEMA "%s" GRANT SELECT ON SEQUENCES TO "%s"`,
schema,
baseUsername,
),
)
if err != nil {
return "", "", fmt.Errorf(
"failed to set default privileges for sequences in schema %s: %w",
schema,
err,
)
}
}
// Step 7: Verify user creation before committing
// Step 8: Discover all roles that own objects in each schema
// This is needed because ALTER DEFAULT PRIVILEGES only applies to objects created by the current role.
// To handle tables created by OTHER users (like the GitHub issue with partitioned tables),
// we need to set "ALTER DEFAULT PRIVILEGES FOR ROLE <owner>" for each object owner.
// Filter by IncludeSchemas if specified.
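// A minimal sketch of the statement issued per discovered owner, assuming an owner
// role "app_owner", schema "app_schema", and read-only user "databasus-abcd1234":
//
// ALTER DEFAULT PRIVILEGES FOR ROLE "app_owner" IN SCHEMA "app_schema"
//     GRANT SELECT ON TABLES TO "databasus-abcd1234";
//
// Without the FOR ROLE clause, default privileges would only cover objects later
// created by the role running this transaction, so partitions created afterwards
// by app_owner would remain invisible to the read-only user.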
type SchemaOwner struct {
SchemaName string
RoleName string
}
var ownerRows pgx.Rows
if len(p.IncludeSchemas) > 0 {
ownerRows, err = tx.Query(ctx, `
SELECT DISTINCT n.nspname as schema_name, pg_get_userbyid(c.relowner) as role_name
FROM pg_class c
JOIN pg_namespace n ON c.relnamespace = n.oid
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
AND n.nspname = ANY($1::text[])
AND c.relkind IN ('r', 'p', 'v', 'm', 'f')
AND pg_get_userbyid(c.relowner) != current_user
ORDER BY n.nspname, role_name
`, p.IncludeSchemas)
} else {
ownerRows, err = tx.Query(ctx, `
SELECT DISTINCT n.nspname as schema_name, pg_get_userbyid(c.relowner) as role_name
FROM pg_class c
JOIN pg_namespace n ON c.relnamespace = n.oid
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
AND c.relkind IN ('r', 'p', 'v', 'm', 'f')
AND pg_get_userbyid(c.relowner) != current_user
ORDER BY n.nspname, role_name
`)
}
if err != nil {
// Log warning but continue - this is a best-effort enhancement
logger.Warn("Failed to query object owners for default privileges", "error", err)
} else {
var schemaOwners []SchemaOwner
for ownerRows.Next() {
var so SchemaOwner
if err := ownerRows.Scan(&so.SchemaName, &so.RoleName); err != nil {
ownerRows.Close()
logger.Warn("Failed to scan schema owner", "error", err)
break
}
schemaOwners = append(schemaOwners, so)
}
ownerRows.Close()
if err := ownerRows.Err(); err != nil {
logger.Warn("Error iterating schema owners", "error", err)
}
// Step 9: Set default privileges FOR ROLE for each object owner
// Note: This may fail for some roles due to permission issues (e.g., roles owned by other superusers)
// We log warnings but continue - user creation should succeed even if some roles can't be configured
for _, so := range schemaOwners {
// Try to set default privileges for tables
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`ALTER DEFAULT PRIVILEGES FOR ROLE "%s" IN SCHEMA "%s" GRANT SELECT ON TABLES TO "%s"`,
so.RoleName,
so.SchemaName,
baseUsername,
),
)
if err != nil {
logger.Warn(
"Failed to set default privileges for role (tables)",
"error",
err,
"role",
so.RoleName,
"schema",
so.SchemaName,
"readonly_user",
baseUsername,
)
}
// Try to set default privileges for sequences
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`ALTER DEFAULT PRIVILEGES FOR ROLE "%s" IN SCHEMA "%s" GRANT SELECT ON SEQUENCES TO "%s"`,
so.RoleName,
so.SchemaName,
baseUsername,
),
)
if err != nil {
logger.Warn(
"Failed to set default privileges for role (sequences)",
"error",
err,
"role",
so.RoleName,
"schema",
so.SchemaName,
"readonly_user",
baseUsername,
)
}
}
if len(schemaOwners) > 0 {
logger.Info(
"Set default privileges for existing object owners",
"readonly_user",
baseUsername,
"owner_count",
len(schemaOwners),
)
}
}
// Step 10: Verify user creation before committing
var verifyUsername string
err = tx.QueryRow(ctx, fmt.Sprintf(`SELECT rolname FROM pg_roles WHERE rolname = '%s'`, baseUsername)).
Scan(&verifyUsername)
@@ -851,7 +1036,15 @@ func checkBackupPermissions(
}
if err != nil {
return fmt.Errorf("cannot check SELECT privileges: %w", err)
// If the user doesn't have USAGE on the schema, has_table_privilege will fail
// with "permission denied for schema". This means they definitely don't have
// SELECT privileges, so treat this as missing permissions rather than an error.
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == "42501" { // insufficient_privilege
selectableTableCount = 0
} else {
return fmt.Errorf("cannot check SELECT privileges: %w", err)
}
}
if selectableTableCount == 0 {
missingPrivileges = append(missingPrivileges, "SELECT on tables")

View File

@@ -709,6 +709,344 @@ func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
assert.Contains(t, err.Error(), "permission denied")
}
func Test_CreateReadOnlyUser_WithPublicSchema_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP TABLE IF EXISTS public_schema_test CASCADE;
CREATE TABLE public_schema_test (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO public_schema_test (data) VALUES ('test1'), ('test2');
`)
assert.NoError(t, err)
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, username)
assert.NotEmpty(t, password)
assert.True(t, strings.HasPrefix(username, "databasus-"))
readOnlyModel := &PostgresqlDatabase{
Version: pgModel.Version,
Host: pgModel.Host,
Port: pgModel.Port,
Username: username,
Password: password,
Database: pgModel.Database,
IsHttps: false,
}
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.True(t, isReadOnly, "User should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
readOnlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
username,
password,
container.Database,
)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var count int
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM public_schema_test")
assert.NoError(t, err)
assert.Equal(t, 2, count)
_, err = readOnlyConn.Exec(
"INSERT INTO public_schema_test (data) VALUES ('should-fail')",
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("CREATE TABLE public.hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
})
}
}
func Test_CreateReadOnlyUser_WithoutPublicSchema_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP SCHEMA IF EXISTS public CASCADE;
DROP SCHEMA IF EXISTS app_schema CASCADE;
DROP SCHEMA IF EXISTS data_schema CASCADE;
CREATE SCHEMA app_schema;
CREATE SCHEMA data_schema;
CREATE TABLE app_schema.users (
id SERIAL PRIMARY KEY,
username TEXT NOT NULL
);
CREATE TABLE data_schema.records (
id SERIAL PRIMARY KEY,
info TEXT NOT NULL
);
INSERT INTO app_schema.users (username) VALUES ('user1'), ('user2');
INSERT INTO data_schema.records (info) VALUES ('record1'), ('record2');
`)
assert.NoError(t, err)
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err, "CreateReadOnlyUser should succeed without public schema")
assert.NotEmpty(t, username)
assert.NotEmpty(t, password)
assert.True(t, strings.HasPrefix(username, "databasus-"))
readOnlyModel := &PostgresqlDatabase{
Version: pgModel.Version,
Host: pgModel.Host,
Port: pgModel.Port,
Username: username,
Password: password,
Database: pgModel.Database,
IsHttps: false,
}
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.True(t, isReadOnly, "User should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
readOnlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
username,
password,
container.Database,
)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var userCount int
err = readOnlyConn.Get(&userCount, "SELECT COUNT(*) FROM app_schema.users")
assert.NoError(t, err)
assert.Equal(t, 2, userCount)
var recordCount int
err = readOnlyConn.Get(&recordCount, "SELECT COUNT(*) FROM data_schema.records")
assert.NoError(t, err)
assert.Equal(t, 2, recordCount)
_, err = readOnlyConn.Exec(
"INSERT INTO app_schema.users (username) VALUES ('should-fail')",
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("CREATE TABLE app_schema.hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("CREATE TABLE data_schema.hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
_, err = container.DB.Exec(`
DROP SCHEMA IF EXISTS app_schema CASCADE;
DROP SCHEMA IF EXISTS data_schema CASCADE;
CREATE SCHEMA IF NOT EXISTS public;
`)
assert.NoError(t, err)
})
}
}
func Test_CreateReadOnlyUser_PublicSchemaExistsButNoPermissions_ReturnsError(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
limitedAdminUsername := fmt.Sprintf("limited_admin_%s", uuid.New().String()[:8])
limitedAdminPassword := "limited_password_123"
_, err := container.DB.Exec(`
CREATE SCHEMA IF NOT EXISTS public;
DROP TABLE IF EXISTS public.permission_test_table CASCADE;
CREATE TABLE public.permission_test_table (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO public.permission_test_table (data) VALUES ('test1');
`)
assert.NoError(t, err)
_, err = container.DB.Exec(`GRANT CREATE ON SCHEMA public TO PUBLIC`)
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN CREATEROLE`,
limitedAdminUsername,
limitedAdminPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
container.Database,
limitedAdminUsername,
))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, limitedAdminUsername),
)
_, _ = container.DB.Exec(
fmt.Sprintf(`DROP USER IF EXISTS "%s"`, limitedAdminUsername),
)
_, _ = container.DB.Exec(`REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
}()
limitedAdminDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
limitedAdminUsername,
limitedAdminPassword,
container.Database,
)
limitedAdminConn, err := sqlx.Connect("postgres", limitedAdminDSN)
assert.NoError(t, err)
defer limitedAdminConn.Close()
pgModel := &PostgresqlDatabase{
Version: tools.GetPostgresqlVersionEnum(tc.version),
Host: container.Host,
Port: container.Port,
Username: limitedAdminUsername,
Password: limitedAdminPassword,
Database: &container.Database,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.Error(
t,
err,
"CreateReadOnlyUser should fail when admin lacks permissions to secure public schema",
)
if err != nil {
errorMsg := err.Error()
hasExpectedError := strings.Contains(
errorMsg,
"failed to revoke CREATE from PUBLIC on existing public schema",
) ||
strings.Contains(errorMsg, "permission denied for schema public") ||
strings.Contains(errorMsg, "failed to grant")
assert.True(
t,
hasExpectedError,
"Error should indicate permission issues with public schema, got: %s",
errorMsg,
)
}
assert.Empty(t, username)
assert.Empty(t, password)
})
}
}
func Test_Validate_WhenLocalhostAndDatabasus_ReturnsError(t *testing.T) {
testCases := []struct {
name string
@@ -981,6 +1319,346 @@ type PostgresContainer struct {
DB *sqlx.DB
}
func Test_CreateReadOnlyUser_TablesCreatedByDifferentUser_ReadOnlyUserCanRead(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
// Step 1: Create a second database user who will create tables
userCreatorUsername := fmt.Sprintf("user_creator_%s", uuid.New().String()[:8])
userCreatorPassword := "creator_password_123"
_, err := container.DB.Exec(fmt.Sprintf(
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
userCreatorUsername,
userCreatorPassword,
))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, userCreatorUsername))
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, userCreatorUsername))
}()
// Step 2: Grant the user_creator privileges to connect and create tables
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
container.Database,
userCreatorUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT USAGE ON SCHEMA public TO "%s"`,
userCreatorUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CREATE ON SCHEMA public TO "%s"`,
userCreatorUsername,
))
assert.NoError(t, err)
// Step 2b: Create an initial table by user_creator so they become an object owner
// This is important because our fix discovers existing object owners
userCreatorDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
userCreatorUsername,
userCreatorPassword,
container.Database,
)
userCreatorConn, err := sqlx.Connect("postgres", userCreatorDSN)
assert.NoError(t, err)
defer userCreatorConn.Close()
initialTableName := fmt.Sprintf(
"public.initial_table_%s",
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
)
_, err = userCreatorConn.Exec(fmt.Sprintf(`
CREATE TABLE %s (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO %s (data) VALUES ('initial_data');
`, initialTableName, initialTableName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s CASCADE`, initialTableName))
}()
// Step 3: NOW create read-only user via Databasus (as admin)
// At this point, user_creator already owns objects, so ALTER DEFAULT PRIVILEGES FOR ROLE should apply
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.NotEmpty(t, readonlyUsername)
assert.NotEmpty(t, readonlyPassword)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, readonlyUsername))
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, readonlyUsername))
}()
// Step 4: user_creator creates a NEW table AFTER the read-only user was created
// This table should automatically grant SELECT to the read-only user via ALTER DEFAULT PRIVILEGES FOR ROLE
tableName := fmt.Sprintf(
"public.future_table_%s",
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
)
_, err = userCreatorConn.Exec(fmt.Sprintf(`
CREATE TABLE %s (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO %s (data) VALUES ('test_data_1'), ('test_data_2');
`, tableName, tableName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s CASCADE`, tableName))
}()
// Step 5: Connect as read-only user and verify it can SELECT from the new table
readonlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
readonlyUsername,
readonlyPassword,
container.Database,
)
readonlyConn, err := sqlx.Connect("postgres", readonlyDSN)
assert.NoError(t, err)
defer readonlyConn.Close()
var count int
err = readonlyConn.Get(&count, fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName))
assert.NoError(t, err)
assert.Equal(
t,
2,
count,
"Read-only user should be able to SELECT from table created by different user",
)
// Step 6: Verify read-only user cannot write to the table
_, err = readonlyConn.Exec(
fmt.Sprintf("INSERT INTO %s (data) VALUES ('should-fail')", tableName),
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
// Step 7: Verify pg_dump operations (LOCK TABLE) work
// pg_dump needs to lock tables in ACCESS SHARE MODE for consistent backup
tx, err := readonlyConn.Begin()
assert.NoError(t, err)
defer tx.Rollback()
_, err = tx.Exec(fmt.Sprintf("LOCK TABLE %s IN ACCESS SHARE MODE", tableName))
assert.NoError(t, err, "Read-only user should be able to LOCK TABLE (needed for pg_dump)")
err = tx.Commit()
assert.NoError(t, err)
}
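// Illustrative sketch (hypothetical role and user names): the SQL this test
// depends on. Default privileges in PostgreSQL are scoped per owner role, so
// they have to be altered FOR every role that already owns objects in the
// schema; otherwise tables created later by that role would not be readable
// by the backup user.
const grantFutureReadAccess = `
ALTER DEFAULT PRIVILEGES FOR ROLE "user_creator" IN SCHEMA "public"
    GRANT SELECT ON TABLES TO "databasus-readonly";
ALTER DEFAULT PRIVILEGES FOR ROLE "user_creator" IN SCHEMA "public"
    GRANT SELECT ON SEQUENCES TO "databasus-readonly";
`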
func Test_CreateReadOnlyUser_WithIncludeSchemas_OnlyGrantsAccessToSpecifiedSchemas(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
// Step 1: Create multiple schemas and tables
_, err := container.DB.Exec(`
DROP SCHEMA IF EXISTS included_schema CASCADE;
DROP SCHEMA IF EXISTS excluded_schema CASCADE;
CREATE SCHEMA included_schema;
CREATE SCHEMA excluded_schema;
CREATE TABLE public.public_table (id INT, data TEXT);
INSERT INTO public.public_table VALUES (1, 'public_data');
CREATE TABLE included_schema.included_table (id INT, data TEXT);
INSERT INTO included_schema.included_table VALUES (2, 'included_data');
CREATE TABLE excluded_schema.excluded_table (id INT, data TEXT);
INSERT INTO excluded_schema.excluded_table VALUES (3, 'excluded_data');
`)
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(`DROP SCHEMA IF EXISTS included_schema CASCADE`)
_, _ = container.DB.Exec(`DROP SCHEMA IF EXISTS excluded_schema CASCADE`)
}()
// Step 2: Create a second user who owns tables in both included and excluded schemas
userCreatorUsername := fmt.Sprintf("user_creator_%s", uuid.New().String()[:8])
userCreatorPassword := "creator_password_123"
_, err = container.DB.Exec(fmt.Sprintf(
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
userCreatorUsername,
userCreatorPassword,
))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, userCreatorUsername))
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, userCreatorUsername))
}()
// Grant privileges to user_creator
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
container.Database,
userCreatorUsername,
))
assert.NoError(t, err)
for _, schema := range []string{"public", "included_schema", "excluded_schema"} {
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT USAGE, CREATE ON SCHEMA %s TO "%s"`,
schema,
userCreatorUsername,
))
assert.NoError(t, err)
}
// User_creator creates tables in included and excluded schemas
userCreatorDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
userCreatorUsername,
userCreatorPassword,
container.Database,
)
userCreatorConn, err := sqlx.Connect("postgres", userCreatorDSN)
assert.NoError(t, err)
defer userCreatorConn.Close()
_, err = userCreatorConn.Exec(`
CREATE TABLE included_schema.user_table (id INT, data TEXT);
INSERT INTO included_schema.user_table VALUES (4, 'user_included_data');
CREATE TABLE excluded_schema.user_excluded_table (id INT, data TEXT);
INSERT INTO excluded_schema.user_excluded_table VALUES (5, 'user_excluded_data');
`)
assert.NoError(t, err)
// Step 3: Create read-only user with IncludeSchemas = ["public", "included_schema"]
pgModel := createPostgresModel(container)
pgModel.IncludeSchemas = []string{"public", "included_schema"}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.NotEmpty(t, readonlyUsername)
assert.NotEmpty(t, readonlyPassword)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, readonlyUsername))
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, readonlyUsername))
}()
// Step 4: Connect as read-only user
readonlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
readonlyUsername,
readonlyPassword,
container.Database,
)
readonlyConn, err := sqlx.Connect("postgres", readonlyDSN)
assert.NoError(t, err)
defer readonlyConn.Close()
// Step 5: Verify read-only user CAN access included schemas
var publicData string
err = readonlyConn.Get(&publicData, "SELECT data FROM public.public_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "public_data", publicData)
var includedData string
err = readonlyConn.Get(&includedData, "SELECT data FROM included_schema.included_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "included_data", includedData)
var userIncludedData string
err = readonlyConn.Get(&userIncludedData, "SELECT data FROM included_schema.user_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "user_included_data", userIncludedData)
// Step 6: Verify read-only user CANNOT access excluded schema
var excludedData string
err = readonlyConn.Get(&excludedData, "SELECT data FROM excluded_schema.excluded_table LIMIT 1")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
err = readonlyConn.Get(
&excludedData,
"SELECT data FROM excluded_schema.user_excluded_table LIMIT 1",
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
// Step 7: Verify future tables in included schemas are accessible
_, err = userCreatorConn.Exec(`
CREATE TABLE included_schema.future_table (id INT, data TEXT);
INSERT INTO included_schema.future_table VALUES (6, 'future_data');
`)
assert.NoError(t, err)
var futureData string
err = readonlyConn.Get(&futureData, "SELECT data FROM included_schema.future_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(
t,
"future_data",
futureData,
"Read-only user should access future tables in included schemas via ALTER DEFAULT PRIVILEGES FOR ROLE",
)
// Step 8: Verify future tables in excluded schema are NOT accessible
_, err = userCreatorConn.Exec(`
CREATE TABLE excluded_schema.future_excluded_table (id INT, data TEXT);
INSERT INTO excluded_schema.future_excluded_table VALUES (7, 'future_excluded_data');
`)
assert.NoError(t, err)
var futureExcludedData string
err = readonlyConn.Get(
&futureExcludedData,
"SELECT data FROM excluded_schema.future_excluded_table LIMIT 1",
)
assert.Error(t, err)
assert.Contains(
t,
err.Error(),
"permission denied",
"Read-only user should NOT access tables in excluded schemas",
)
}
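// Usage sketch (illustrative values): restricting the generated read-only user
// to specific schemas, as the test above exercises. Inferred from the test
// expectations: schemas outside IncludeSchemas receive no USAGE/SELECT grants,
// so even a plain SELECT there fails with "permission denied".
pgModel := createPostgresModel(container)
pgModel.IncludeSchemas = []string{"public", "included_schema"} // excluded_schema gets no grants
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())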
func connectToPostgresContainer(t *testing.T, port string) *PostgresContainer {
dbName := "testdb"
password := "testpassword"

View File

@@ -1,6 +1,7 @@
package databases
import (
"context"
"databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/databases/databases/mongodb"
"databasus-backend/internal/features/databases/databases/mysql"
@@ -84,6 +85,25 @@ func (d *Database) TestConnection(
return d.getSpecificDatabase().TestConnection(logger, encryptor, d.ID)
}
func (d *Database) IsUserReadOnly(
ctx context.Context,
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
) (bool, []string, error) {
switch d.Type {
case DatabaseTypePostgres:
return d.Postgresql.IsUserReadOnly(ctx, logger, encryptor, d.ID)
case DatabaseTypeMysql:
return d.Mysql.IsUserReadOnly(ctx, logger, encryptor, d.ID)
case DatabaseTypeMariadb:
return d.Mariadb.IsUserReadOnly(ctx, logger, encryptor, d.ID)
case DatabaseTypeMongodb:
return d.Mongodb.IsUserReadOnly(ctx, logger, encryptor, d.ID)
default:
return false, nil, errors.New("read-only check not supported for this database type")
}
}
func (d *Database) HideSensitiveData() {
d.getSpecificDatabase().HideSensitiveData()
}

View File

@@ -7,6 +7,7 @@ import (
"log/slog"
"time"
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/databases/databases/mongodb"
@@ -86,6 +87,23 @@ func (s *DatabaseService) CreateDatabase(
return nil, fmt.Errorf("failed to auto-detect database data: %w", err)
}
if config.GetEnv().IsCloud {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
isReadOnly, permissions, err := database.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor)
if err != nil {
return nil, fmt.Errorf("failed to verify user permissions: %w", err)
}
if !isReadOnly {
return nil, fmt.Errorf(
"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
permissions,
)
}
}
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
return nil, fmt.Errorf("failed to encrypt sensitive fields: %w", err)
}
@@ -153,6 +171,29 @@ func (s *DatabaseService) UpdateDatabase(
return fmt.Errorf("failed to auto-detect database data: %w", err)
}
if config.GetEnv().IsCloud {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
isReadOnly, permissions, err := existingDatabase.IsUserReadOnly(
ctx,
s.logger,
s.fieldEncryptor,
)
if err != nil {
return fmt.Errorf("failed to verify user permissions: %w", err)
}
if !isReadOnly {
return fmt.Errorf(
"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
permissions,
)
}
}
oldName := existingDatabase.Name
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to encrypt sensitive fields: %w", err)
}
@@ -162,11 +203,23 @@ func (s *DatabaseService) UpdateDatabase(
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Database updated: %s", existingDatabase.Name),
&user.ID,
existingDatabase.WorkspaceID,
)
if oldName != existingDatabase.Name {
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Database updated and renamed from '%s' to '%s'",
oldName,
existingDatabase.Name,
),
&user.ID,
existingDatabase.WorkspaceID,
)
} else {
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Database updated: %s", existingDatabase.Name),
&user.ID,
existingDatabase.WorkspaceID,
)
}
return nil
}
@@ -532,9 +585,19 @@ func (s *DatabaseService) TransferDatabaseToWorkspace(
return err
}
sourceWorkspace, err := s.workspaceService.GetWorkspaceByID(*sourceWorkspaceID)
if err != nil {
return fmt.Errorf("failed to get source workspace: %w", err)
}
targetWorkspace, err := s.workspaceService.GetWorkspaceByID(targetWorkspaceID)
if err != nil {
return fmt.Errorf("failed to get target workspace: %w", err)
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Database transferred: %s from workspace %s to workspace %s",
database.Name, sourceWorkspaceID, targetWorkspaceID),
fmt.Sprintf("Database transferred: %s from workspace '%s' to workspace '%s'",
database.Name, sourceWorkspace.Name, targetWorkspace.Name),
nil,
&targetWorkspaceID,
)
@@ -649,38 +712,7 @@ func (s *DatabaseService) IsUserReadOnly(
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
switch usingDatabase.Type {
case DatabaseTypePostgres:
return usingDatabase.Postgresql.IsUserReadOnly(
ctx,
s.logger,
s.fieldEncryptor,
usingDatabase.ID,
)
case DatabaseTypeMysql:
return usingDatabase.Mysql.IsUserReadOnly(
ctx,
s.logger,
s.fieldEncryptor,
usingDatabase.ID,
)
case DatabaseTypeMariadb:
return usingDatabase.Mariadb.IsUserReadOnly(
ctx,
s.logger,
s.fieldEncryptor,
usingDatabase.ID,
)
case DatabaseTypeMongodb:
return usingDatabase.Mongodb.IsUserReadOnly(
ctx,
s.logger,
s.fieldEncryptor,
usingDatabase.ID,
)
default:
return false, nil, errors.New("read-only check not supported for this database type")
}
return usingDatabase.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor)
}
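// Sketch of the call site after the refactor above: callers no longer switch on
// the database type themselves, they delegate to Database.IsUserReadOnly, which
// dispatches to the engine-specific implementation. (Variable names here are
// illustrative.)
isReadOnly, privileges, err := usingDatabase.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor)
if err != nil {
	return fmt.Errorf("failed to verify user permissions: %w", err)
}
if !isReadOnly {
	return fmt.Errorf("user has write privileges: %v", privileges)
}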
func (s *DatabaseService) CreateReadOnlyUser(

View File

@@ -10,6 +10,7 @@ import (
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/storage"
"databasus-backend/internal/util/tools"
"github.com/google/uuid"
@@ -70,12 +71,13 @@ func GetTestMongodbConfig() *mongodb.MongodbDatabase {
return &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: config.GetEnv().TestLocalhost,
Port: port,
Port: &port,
Username: "root",
Password: "rootpassword",
Database: "testdb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
CpuCount: 1,
}
}
@@ -104,6 +106,19 @@ func CreateTestDatabase(
}
func RemoveTestDatabase(database *Database) {
// Delete backups and backup configs associated with this database
// We hardcode SQL here because we cannot call the backups feature directly:
// that would create a circular import (backups already imports databases, so databases cannot import backups)
db := storage.GetDb()
if err := db.Exec("DELETE FROM backups WHERE database_id = ?", database.ID).Error; err != nil {
panic(fmt.Sprintf("failed to delete backups: %v", err))
}
if err := db.Exec("DELETE FROM backup_configs WHERE database_id = ?", database.ID).Error; err != nil {
panic(fmt.Sprintf("failed to delete backup config: %v", err))
}
err := databaseRepository.Delete(database.ID)
if err != nil {
panic(err)

View File

@@ -12,6 +12,15 @@ import (
type DiskService struct{}
func (s *DiskService) GetDiskUsage() (*DiskUsage, error) {
if config.GetEnv().IsCloud {
return &DiskUsage{
Platform: PlatformLinux,
TotalSpaceBytes: 100,
UsedSpaceBytes: 0,
FreeSpaceBytes: 100,
}, nil
}
platform := s.detectPlatform()
var path string

View File

@@ -0,0 +1,22 @@
package email
import (
"databasus-backend/internal/config"
"databasus-backend/internal/util/logger"
)
var env = config.GetEnv()
var log = logger.GetLogger()
var emailSMTPSender = &EmailSMTPSender{
log,
env.SMTPHost,
env.SMTPPort,
env.SMTPUser,
env.SMTPPassword,
env.SMTPHost != "" && env.SMTPPort != 0,
}
func GetEmailSMTPSender() *EmailSMTPSender {
return emailSMTPSender
}

View File

@@ -0,0 +1,245 @@
package email
import (
"crypto/tls"
"fmt"
"log/slog"
"mime"
"net"
"net/smtp"
"time"
)
const (
ImplicitTLSPort = 465
DefaultTimeout = 5 * time.Second
DefaultHelloName = "localhost"
MIMETypeHTML = "text/html"
MIMECharsetUTF8 = "UTF-8"
)
type EmailSMTPSender struct {
logger *slog.Logger
smtpHost string
smtpPort int
smtpUser string
smtpPassword string
isConfigured bool
}
func (s *EmailSMTPSender) SendEmail(to, subject, body string) error {
if !s.isConfigured {
s.logger.Warn("Skipping email send, SMTP not initialized", "to", to, "subject", subject)
return nil
}
from := s.smtpUser
if from == "" {
from = "noreply@" + s.smtpHost
}
emailContent := s.buildEmailContent(to, subject, body, from)
isAuthRequired := s.smtpUser != "" && s.smtpPassword != ""
if s.smtpPort == ImplicitTLSPort {
return s.sendImplicitTLS(to, from, emailContent, isAuthRequired)
}
return s.sendStartTLS(to, from, emailContent, isAuthRequired)
}
func (s *EmailSMTPSender) buildEmailContent(to, subject, body, from string) []byte {
// Encode Subject header using RFC 2047 to avoid SMTPUTF8 requirement
encodedSubject := encodeRFC2047(subject)
subjectHeader := fmt.Sprintf("Subject: %s\r\n", encodedSubject)
dateHeader := fmt.Sprintf("Date: %s\r\n", time.Now().UTC().Format(time.RFC1123Z))
mimeHeaders := fmt.Sprintf(
"MIME-version: 1.0;\nContent-Type: %s; charset=\"%s\";\n\n",
MIMETypeHTML,
MIMECharsetUTF8,
)
// Encode From header display name if it contains non-ASCII
encodedFrom := encodeRFC2047(from)
fromHeader := fmt.Sprintf("From: %s\r\n", encodedFrom)
toHeader := fmt.Sprintf("To: %s\r\n", to)
return []byte(fromHeader + toHeader + subjectHeader + dateHeader + mimeHeaders + body)
}
func (s *EmailSMTPSender) sendImplicitTLS(
to, from string,
emailContent []byte,
isAuthRequired bool,
) error {
createClient := func() (*smtp.Client, func(), error) {
return s.createImplicitTLSClient()
}
client, cleanup, err := s.authenticateWithRetry(createClient, isAuthRequired)
if err != nil {
return err
}
defer cleanup()
return s.sendEmail(client, to, from, emailContent)
}
func (s *EmailSMTPSender) sendStartTLS(
to, from string,
emailContent []byte,
isAuthRequired bool,
) error {
createClient := func() (*smtp.Client, func(), error) {
return s.createStartTLSClient()
}
client, cleanup, err := s.authenticateWithRetry(createClient, isAuthRequired)
if err != nil {
return err
}
defer cleanup()
return s.sendEmail(client, to, from, emailContent)
}
func (s *EmailSMTPSender) createImplicitTLSClient() (*smtp.Client, func(), error) {
addr := net.JoinHostPort(s.smtpHost, fmt.Sprintf("%d", s.smtpPort))
tlsConfig := &tls.Config{ServerName: s.smtpHost}
dialer := &net.Dialer{Timeout: DefaultTimeout}
conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
if err != nil {
return nil, nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
}
client, err := smtp.NewClient(conn, s.smtpHost)
if err != nil {
_ = conn.Close()
return nil, nil, fmt.Errorf("failed to create SMTP client: %w", err)
}
return client, func() { _ = client.Quit() }, nil
}
func (s *EmailSMTPSender) createStartTLSClient() (*smtp.Client, func(), error) {
addr := net.JoinHostPort(s.smtpHost, fmt.Sprintf("%d", s.smtpPort))
dialer := &net.Dialer{Timeout: DefaultTimeout}
conn, err := dialer.Dial("tcp", addr)
if err != nil {
return nil, nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
}
client, err := smtp.NewClient(conn, s.smtpHost)
if err != nil {
_ = conn.Close()
return nil, nil, fmt.Errorf("failed to create SMTP client: %w", err)
}
if err := client.Hello(DefaultHelloName); err != nil {
_ = client.Quit()
_ = conn.Close()
return nil, nil, fmt.Errorf("SMTP hello failed: %w", err)
}
if ok, _ := client.Extension("STARTTLS"); ok {
if err := client.StartTLS(&tls.Config{ServerName: s.smtpHost}); err != nil {
_ = client.Quit()
_ = conn.Close()
return nil, nil, fmt.Errorf("STARTTLS failed: %w", err)
}
}
return client, func() { _ = client.Quit() }, nil
}
func (s *EmailSMTPSender) authenticateWithRetry(
createClient func() (*smtp.Client, func(), error),
isAuthRequired bool,
) (*smtp.Client, func(), error) {
client, cleanup, err := createClient()
if err != nil {
return nil, nil, err
}
if !isAuthRequired {
return client, cleanup, nil
}
// Try PLAIN auth first
plainAuth := smtp.PlainAuth("", s.smtpUser, s.smtpPassword, s.smtpHost)
if err := client.Auth(plainAuth); err == nil {
return client, cleanup, nil
}
// PLAIN auth failed, connection may be closed - recreate and try LOGIN auth
cleanup()
client, cleanup, err = createClient()
if err != nil {
return nil, nil, err
}
loginAuth := &loginAuth{username: s.smtpUser, password: s.smtpPassword}
if err := client.Auth(loginAuth); err != nil {
cleanup()
return nil, nil, fmt.Errorf("SMTP authentication failed: %w", err)
}
return client, cleanup, nil
}
func (s *EmailSMTPSender) sendEmail(client *smtp.Client, to, from string, content []byte) error {
if err := client.Mail(from); err != nil {
return fmt.Errorf("failed to set sender: %w", err)
}
if err := client.Rcpt(to); err != nil {
return fmt.Errorf("failed to set recipient: %w", err)
}
writer, err := client.Data()
if err != nil {
return fmt.Errorf("failed to get data writer: %w", err)
}
if _, err = writer.Write(content); err != nil {
return fmt.Errorf("failed to write email content: %w", err)
}
if err = writer.Close(); err != nil {
return fmt.Errorf("failed to close data writer: %w", err)
}
return nil
}
func encodeRFC2047(s string) string {
return mime.QEncoding.Encode("UTF-8", s)
}
type loginAuth struct {
username string
password string
}
func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
return "LOGIN", []byte{}, nil
}
func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
if more {
switch string(fromServer) {
case "Username:", "User Name\x00":
return []byte(a.username), nil
case "Password:", "Password\x00":
return []byte(a.password), nil
default:
return []byte(a.username), nil
}
}
return nil, nil
}
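// Illustrative sketch: what encodeRFC2047 produces. mime.QEncoding leaves plain
// ASCII untouched and Q-encodes everything else, so non-ASCII subjects remain
// deliverable to servers without the SMTPUTF8 extension. The exact encoded form
// in the trailing comment is an assumption. Port selection above follows the
// common convention: 465 means implicit TLS, any other port is dialed as plain
// TCP and upgraded via STARTTLS when the server advertises it.
package main

import (
	"fmt"
	"mime"
)

func main() {
	fmt.Println(mime.QEncoding.Encode("UTF-8", "Backup finished")) // unchanged: pure ASCII
	fmt.Println(mime.QEncoding.Encode("UTF-8", "Résumé ✅"))       // roughly =?UTF-8?q?R=C3=A9sum=C3=A9_=E2=9C=85?=
}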

View File

@@ -144,6 +144,10 @@ func Test_GetAttemptsByDatabase_PermissionsEnforced(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "forbidden")
}
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -181,6 +185,10 @@ func Test_GetAttemptsByDatabase_FiltersByAfterDate(t *testing.T) {
for _, attempt := range response {
assert.True(t, attempt.CreatedAt.After(afterDate) || attempt.CreatedAt.Equal(afterDate))
}
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_GetAttemptsByDatabase_ReturnsEmptyListForNewDatabase(t *testing.T) {
@@ -201,6 +209,10 @@ func Test_GetAttemptsByDatabase_ReturnsEmptyListForNewDatabase(t *testing.T) {
)
assert.Equal(t, 0, len(response))
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func createTestDatabaseViaAPI(

View File

@@ -130,6 +130,10 @@ func Test_SaveHealthcheckConfig_PermissionsEnforced(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -162,6 +166,10 @@ func Test_SaveHealthcheckConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_GetHealthcheckConfig_PermissionsEnforced(t *testing.T) {
@@ -268,6 +276,10 @@ func Test_GetHealthcheckConfig_PermissionsEnforced(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -295,6 +307,10 @@ func Test_GetHealthcheckConfig_ReturnsDefaultConfigForNewDatabase(t *testing.T)
assert.Equal(t, 1, response.IntervalMinutes)
assert.Equal(t, 3, response.AttemptsBeforeConcideredAsDown)
assert.Equal(t, 7, response.StoreAttemptsDays)
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func createTestDatabaseViaAPI(

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"log/slog"
"mime"
"net"
"net/smtp"
"time"
@@ -115,16 +116,35 @@ func (e *EmailNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor
return nil
}
// encodeRFC2047 encodes a string using RFC 2047 MIME encoding for email headers
// This ensures compatibility with SMTP servers that don't support SMTPUTF8
func encodeRFC2047(s string) string {
// mime.QEncoding handles UTF-8 → =?UTF-8?Q?...?= encoding
// This allows non-ASCII characters (emojis, accents, etc.) in email headers
// while maintaining compatibility with all SMTP servers
return mime.QEncoding.Encode("UTF-8", s)
}
func (e *EmailNotifier) buildEmailContent(heading, message, from string) []byte {
subject := fmt.Sprintf("Subject: %s\r\n", heading)
mime := fmt.Sprintf(
// Encode Subject header using RFC 2047 to avoid SMTPUTF8 requirement
// This ensures compatibility with SMTP servers that don't support SMTPUTF8
encodedSubject := encodeRFC2047(heading)
subject := fmt.Sprintf("Subject: %s\r\n", encodedSubject)
dateHeader := fmt.Sprintf("Date: %s\r\n", time.Now().UTC().Format(time.RFC1123Z))
mimeHeaders := fmt.Sprintf(
"MIME-version: 1.0;\nContent-Type: %s; charset=\"%s\";\n\n",
MIMETypeHTML,
MIMECharsetUTF8,
)
fromHeader := fmt.Sprintf("From: %s\r\n", from)
// Encode From header display name if it contains non-ASCII
encodedFrom := encodeRFC2047(from)
fromHeader := fmt.Sprintf("From: %s\r\n", encodedFrom)
toHeader := fmt.Sprintf("To: %s\r\n", e.TargetEmail)
return []byte(fromHeader + toHeader + subject + mime + message)
return []byte(fromHeader + toHeader + subject + dateHeader + mimeHeaders + message)
}
func (e *EmailNotifier) sendImplicitTLS(

View File

@@ -58,6 +58,8 @@ func (s *NotifierService) SaveNotifier(
return err
}
oldName := existingNotifier.Name
if err := existingNotifier.Validate(s.fieldEncryptor); err != nil {
return err
}
@@ -67,11 +69,23 @@ func (s *NotifierService) SaveNotifier(
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Notifier updated: %s", existingNotifier.Name),
&user.ID,
&workspaceID,
)
if oldName != existingNotifier.Name {
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Notifier updated and renamed from '%s' to '%s'",
oldName,
existingNotifier.Name,
),
&user.ID,
&workspaceID,
)
} else {
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Notifier updated: %s", existingNotifier.Name),
&user.ID,
&workspaceID,
)
}
} else {
notifier.WorkspaceID = workspaceID
@@ -343,9 +357,19 @@ func (s *NotifierService) TransferNotifierToWorkspace(
return err
}
sourceWorkspace, err := s.workspaceService.GetWorkspaceByID(sourceWorkspaceID)
if err != nil {
return fmt.Errorf("failed to get source workspace: %w", err)
}
targetWorkspace, err := s.workspaceService.GetWorkspaceByID(targetWorkspaceID)
if err != nil {
return fmt.Errorf("failed to get target workspace: %w", err)
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Notifier transferred: %s from workspace %s to workspace %s",
existingNotifier.Name, sourceWorkspaceID, targetWorkspaceID),
fmt.Sprintf("Notifier transferred: %s from workspace '%s' to workspace '%s'",
existingNotifier.Name, sourceWorkspace.Name, targetWorkspace.Name),
&user.ID,
&targetWorkspaceID,
)

View File

@@ -0,0 +1,20 @@
package plans
import (
"databasus-backend/internal/util/logger"
)
var databasePlanRepository = &DatabasePlanRepository{}
var databasePlanService = &DatabasePlanService{
databasePlanRepository,
logger.GetLogger(),
}
func GetDatabasePlanService() *DatabasePlanService {
return databasePlanService
}
func GetDatabasePlanRepository() *DatabasePlanRepository {
return databasePlanRepository
}

View File

@@ -0,0 +1,19 @@
package plans
import (
"databasus-backend/internal/util/period"
"github.com/google/uuid"
)
type DatabasePlan struct {
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;primaryKey;not null"`
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
MaxStoragePeriod period.Period `json:"maxStoragePeriod" gorm:"column:max_storage_period;type:text;not null"`
}
func (p *DatabasePlan) TableName() string {
return "database_plans"
}

View File

@@ -0,0 +1,27 @@
package plans
import (
"databasus-backend/internal/storage"
"github.com/google/uuid"
)
type DatabasePlanRepository struct{}
func (r *DatabasePlanRepository) GetDatabasePlan(databaseID uuid.UUID) (*DatabasePlan, error) {
var databasePlan DatabasePlan
if err := storage.GetDb().Where("database_id = ?", databaseID).First(&databasePlan).Error; err != nil {
if err.Error() == "record not found" {
return nil, nil
}
return nil, err
}
return &databasePlan, nil
}
func (r *DatabasePlanRepository) CreateDatabasePlan(databasePlan *DatabasePlan) error {
return storage.GetDb().Create(&databasePlan).Error
}

View File

@@ -0,0 +1,67 @@
package plans
import (
"databasus-backend/internal/config"
"databasus-backend/internal/util/period"
"log/slog"
"github.com/google/uuid"
)
type DatabasePlanService struct {
databasePlanRepository *DatabasePlanRepository
logger *slog.Logger
}
func (s *DatabasePlanService) GetDatabasePlan(databaseID uuid.UUID) (*DatabasePlan, error) {
plan, err := s.databasePlanRepository.GetDatabasePlan(databaseID)
if err != nil {
return nil, err
}
if plan == nil {
s.logger.Info("no database plan found, creating default plan", "databaseID", databaseID)
defaultPlan := s.createDefaultDatabasePlan(databaseID)
err := s.databasePlanRepository.CreateDatabasePlan(defaultPlan)
if err != nil {
s.logger.Error("failed to create default database plan", "error", err)
return nil, err
}
return defaultPlan, nil
}
return plan, nil
}
func (s *DatabasePlanService) createDefaultDatabasePlan(databaseID uuid.UUID) *DatabasePlan {
var plan DatabasePlan
isCloud := config.GetEnv().IsCloud
if isCloud {
s.logger.Info("creating default database plan for cloud", "databaseID", databaseID)
// for the playground we set storage limits that are enough for testing,
// but not too expensive for Databasus to provide
plan = DatabasePlan{
DatabaseID: databaseID,
MaxBackupSizeMB: 100, // ~ 1.5GB database
MaxBackupsTotalSizeMB: 4000, // ~ 30 daily backups + 10 manual backups
MaxStoragePeriod: period.PeriodWeek,
}
} else {
s.logger.Info("creating default database plan for self hosted", "databaseID", databaseID)
// by default, everything is unlimited in self-hosted mode
plan = DatabasePlan{
DatabaseID: databaseID,
MaxBackupSizeMB: 0,
MaxBackupsTotalSizeMB: 0,
MaxStoragePeriod: period.PeriodForever,
}
}
return &plan
}
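// Usage sketch (the enforcement site is an assumption; it is not part of this
// diff): a caller fetches the plan and treats 0 as "unlimited", matching the
// self-hosted defaults above.
plan, err := plans.GetDatabasePlanService().GetDatabasePlan(database.ID)
if err != nil {
	return err
}
if plan.MaxBackupSizeMB > 0 && backupSizeMB > plan.MaxBackupSizeMB {
	return fmt.Errorf("backup of %d MB exceeds the plan limit of %d MB", backupSizeMB, plan.MaxBackupSizeMB)
}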

View File

@@ -32,7 +32,6 @@ import (
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
users_dto "databasus-backend/internal/features/users/dto"
users_enums "databasus-backend/internal/features/users/enums"
users_services "databasus-backend/internal/features/users/services"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_models "databasus-backend/internal/features/workspaces/models"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
@@ -46,8 +45,10 @@ func Test_GetRestores_WhenUserIsWorkspaceMember_RestoresReturned(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
var restores []*restores_core.Restore
test_utils.MakeGetRequestAndUnmarshal(
@@ -68,8 +69,10 @@ func Test_GetRestores_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -88,8 +91,10 @@ func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
@@ -114,8 +119,10 @@ func Test_RestoreBackup_WhenUserIsWorkspaceMember_RestoreInitiated(t *testing.T)
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
@@ -143,8 +150,10 @@ func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -178,8 +187,10 @@ func Test_RestoreBackup_WithIsExcludeExtensions_FlagPassedCorrectly(t *testing.T
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
@@ -212,8 +223,10 @@ func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
@@ -248,7 +261,7 @@ func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
found := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Database restored from backup") &&
if strings.Contains(log.Message, "Database restored for database") &&
strings.Contains(log.Message, database.Name) {
found = true
break
@@ -296,12 +309,16 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
var database *databases.Database
var backup *backups_core.Backup
var storage *storages.Storage
var request restores_core.RestoreBackupRequest
if tc.dbType == databases.DatabaseTypePostgres {
_, backup = createTestDatabaseWithBackupForRestore(workspace, owner, router)
database, backup = createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
request = restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
@@ -319,7 +336,16 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
owner.Token,
router,
)
storage := createTestStorage(workspace.ID)
database = mysqlDB
storage = createTestStorage(workspace.ID)
defer func() {
// Cleanup in dependency order: backup -> database -> storage
cleanupBackup(backup)
databases.RemoveTestDatabase(mysqlDB)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
}()
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(mysqlDB.ID)
@@ -331,7 +357,8 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
_, err = configService.SaveBackupConfig(config)
assert.NoError(t, err)
backup = createTestBackup(mysqlDB, owner)
backup = createTestBackup(mysqlDB, storage)
request = restores_core.RestoreBackupRequest{
MysqlDatabase: &mysql.MysqlDatabase{
Version: tools.MysqlVersion80,
@@ -519,8 +546,10 @@ func Test_RestoreBackup_WithParallelRestoreInProgress_ReturnsError(t *testing.T)
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
@@ -580,7 +609,7 @@ func createTestDatabaseWithBackupForRestore(
panic(err)
}
backup := createTestBackup(database, owner)
backup := createTestBackup(database, storage)
return database, backup
}
@@ -697,24 +726,14 @@ func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
func createTestBackup(
database *databases.Database,
owner *users_dto.SignInResponseDTO,
storage *storages.Storage,
) *backups_core.Backup {
fieldEncryptor := util_encryption.GetFieldEncryptor()
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
if err != nil {
panic(err)
}
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
if err != nil || len(storages) == 0 {
panic("No storage found for workspace")
}
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storages[0].ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
@@ -729,11 +748,11 @@ func createTestBackup(
dummyContent := []byte("dummy backup content for testing")
reader := strings.NewReader(string(dummyContent))
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
if err := storages[0].SaveFile(
if err := storage.SaveFile(
context.Background(),
fieldEncryptor,
logger,
backup.ID,
backup.ID.String(),
reader,
); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
@@ -741,3 +760,22 @@ func createTestBackup(
return backup
}
func cleanupDatabaseWithBackup(database *databases.Database, backup *backups_core.Backup) {
// Clean up in reverse dependency order
cleanupBackup(backup)
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
// Clean up storage last (after database and backup are removed)
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(database.ID)
if err == nil && config.StorageID != nil {
storages.RemoveTestStorage(*config.StorageID)
}
}
func cleanupBackup(backup *backups_core.Backup) {
repo := &backups_core.BackupRepository{}
repo.DeleteByID(backup.ID)
}

View File

@@ -6,6 +6,7 @@ import (
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
@@ -51,6 +52,7 @@ func SetupDependencies() {
setupOnce.Do(func() {
backups.GetBackupService().AddBackupRemoveListener(restoreService)
backuping.GetBackupCleaner().AddBackupRemoveListener(restoreService)
isSetup.Store(true)
})

View File

@@ -1,6 +1,7 @@
package restores
import (
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
@@ -127,6 +128,13 @@ func (s *RestoreService) RestoreBackupWithAuth(
return err
}
if config.GetEnv().IsCloud {
// in cloud mode we use single-threaded restore only,
// because otherwise we would exhaust local storage space
// (instead of streaming from S3 directly to the DB)
requestDTO.PostgresqlDatabase.CpuCount = 1
}
if err := s.validateVersionCompatibility(backupDatabase, requestDTO); err != nil {
return err
}
@@ -182,11 +190,7 @@ func (s *RestoreService) RestoreBackupWithAuth(
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Database restored from backup %s for database: %s",
backupID.String(),
database.Name,
),
fmt.Sprintf("Database restored for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
@@ -404,11 +408,7 @@ func (s *RestoreService) CancelRestore(
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Restore cancelled for database: %s (ID: %s)",
database.Name,
restoreID.String(),
),
fmt.Sprintf("Restore cancelled for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)

View File

@@ -106,7 +106,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
storage *storages.Storage,
mdbConfig *mariadbtypes.MariadbDatabase,
) error {
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 23*time.Hour)
defer cancel()
go func() {
@@ -141,7 +141,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
defer func() { _ = os.RemoveAll(filepath.Dir(myCnfFile)) }()
// Stream backup directly from storage
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
rawReader, err := storage.GetFile(fieldEncryptor, backup.FileName)
if err != nil {
return fmt.Errorf("failed to get backup file from storage: %w", err)
}

View File

@@ -154,7 +154,7 @@ func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
// Stream backup directly from storage
fieldEncryptor := util_encryption.GetFieldEncryptor()
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
rawReader, err := storage.GetFile(fieldEncryptor, backup.FileName)
if err != nil {
return fmt.Errorf("failed to get backup file from storage: %w", err)
}

View File

@@ -105,7 +105,7 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
storage *storages.Storage,
myConfig *mysqltypes.MysqlDatabase,
) error {
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 23*time.Hour)
defer cancel()
go func() {
@@ -140,7 +140,7 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
defer func() { _ = os.RemoveAll(filepath.Dir(myCnfFile)) }()
// Stream backup directly from storage
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
rawReader, err := storage.GetFile(fieldEncryptor, backup.FileName)
if err != nil {
return fmt.Errorf("failed to get backup file from storage: %w", err)
}

View File

@@ -65,6 +65,13 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
return fmt.Errorf("target database name is required for pg_restore")
}
// Validate CPU count constraint for cloud environments
if config.GetEnv().IsCloud && pg.CpuCount > 1 {
return fmt.Errorf(
"parallel restore (CPU count > 1) is not supported in cloud mode due to storage constraints. Please use CPU count = 1",
)
}
pgBin := tools.GetPostgresqlExecutable(
pg.Version,
"pg_restore",
@@ -145,7 +152,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
"--no-acl",
}
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 23*time.Hour)
defer cancel()
// Monitor for shutdown and parent cancellation
@@ -202,7 +209,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
}
// Get backup stream from storage
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
rawReader, err := storage.GetFile(fieldEncryptor, backup.FileName)
if err != nil {
return fmt.Errorf("failed to get backup file from storage: %w", err)
}
@@ -422,7 +429,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
isExcludeExtensions,
)
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 23*time.Hour)
defer cancel()
// Monitor for shutdown and parent cancellation
@@ -533,12 +540,14 @@ func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
"encrypted",
backup.Encryption == backups_config.BackupEncryptionEncrypted,
)
fieldEncryptor := util_encryption.GetFieldEncryptor()
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
rawReader, err := storage.GetFile(fieldEncryptor, backup.FileName)
if err != nil {
cleanupFunc()
return "", nil, fmt.Errorf("failed to get backup file from storage: %w", err)
}
defer func() {
if err := rawReader.Close(); err != nil {
uc.logger.Error("Failed to close backup reader", "error", err)

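// Sketch of the cancellation pattern the restore usecases above rely on (the
// shutdown channel is illustrative): a generous 23-hour ceiling for very large
// restores, with a watcher goroutine that cancels early on shutdown or when the
// parent context is cancelled.
ctx, cancel := context.WithTimeout(parentCtx, 23*time.Hour)
defer cancel()
go func() {
	select {
	case <-shutdownCh: // e.g. SIGTERM propagated by the application
		cancel()
	case <-ctx.Done():
	}
}()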
View File

@@ -58,7 +58,8 @@ func (c *StorageController) SaveStorage(ctx *gin.Context) {
}
if err := c.storageService.SaveStorage(user, request.WorkspaceID, &request); err != nil {
if errors.Is(err, ErrInsufficientPermissionsToManageStorage) {
if errors.Is(err, ErrInsufficientPermissionsToManageStorage) ||
errors.Is(err, ErrLocalStorageNotAllowedInCloudMode) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -325,7 +326,11 @@ func (c *StorageController) TestStorageConnectionDirect(ctx *gin.Context) {
return
}
if err := c.storageService.TestStorageConnectionDirect(&request); err != nil {
if err := c.storageService.TestStorageConnectionDirect(user, &request); err != nil {
if errors.Is(err, ErrLocalStorageNotAllowedInCloudMode) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

File diff suppressed because it is too large

View File

@@ -33,4 +33,13 @@ var (
ErrStorageHasOtherAttachedDatabasesCannotTransfer = errors.New(
"storage has other attached databases and cannot be transferred",
)
ErrSystemStorageCannotBeTransferred = errors.New(
"system storage cannot be transferred between workspaces",
)
ErrSystemStorageCannotBeMadePrivate = errors.New(
"system storage cannot be changed to non-system",
)
ErrLocalStorageNotAllowedInCloudMode = errors.New(
"local storage can only be managed by administrators in cloud mode",
)
)

View File

@@ -14,13 +14,13 @@ type StorageFileSaver interface {
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error
GetFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) (io.ReadCloser, error)
GetFile(encryptor encryption.FieldEncryptor, fileName string) (io.ReadCloser, error)
DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error
DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error
Validate(encryptor encryption.FieldEncryptor) error
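
With StorageFileSaver keyed by file name, the caller decides what the object is called in the backing store. A minimal caller-side sketch; saver, ctx, encryptor, logger and backupStream are assumed to be in scope, and the file name shown is purely illustrative:

fileName := "mydb-2026-02-17.dump" // illustrative name, chosen by the caller
if err := saver.SaveFile(ctx, encryptor, logger, fileName, backupStream); err != nil {
    return err
}
reader, err := saver.GetFile(encryptor, fileName)
if err != nil {
    return err
}
defer reader.Close()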

View File

@@ -24,6 +24,7 @@ type Storage struct {
Type StorageType `json:"type" gorm:"column:type;not null;type:text"`
Name string `json:"name" gorm:"column:name;not null;type:text"`
LastSaveError *string `json:"lastSaveError" gorm:"column:last_save_error;type:text"`
IsSystem bool `json:"isSystem" gorm:"column:is_system;not null;default:false"`
// specific storage
LocalStorage *local_storage.LocalStorage `json:"localStorage" gorm:"foreignKey:StorageID"`
@@ -40,10 +41,10 @@ func (s *Storage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error {
err := s.getSpecificStorage().SaveFile(ctx, encryptor, logger, fileID, file)
err := s.getSpecificStorage().SaveFile(ctx, encryptor, logger, fileName, file)
if err != nil {
lastSaveError := err.Error()
s.LastSaveError = &lastSaveError
@@ -57,13 +58,13 @@ func (s *Storage) SaveFile(
func (s *Storage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) (io.ReadCloser, error) {
return s.getSpecificStorage().GetFile(encryptor, fileID)
return s.getSpecificStorage().GetFile(encryptor, fileName)
}
func (s *Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
return s.getSpecificStorage().DeleteFile(encryptor, fileID)
func (s *Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
return s.getSpecificStorage().DeleteFile(encryptor, fileName)
}
func (s *Storage) Validate(encryptor encryption.FieldEncryptor) error {
@@ -86,6 +87,17 @@ func (s *Storage) HideSensitiveData() {
s.getSpecificStorage().HideSensitiveData()
}
func (s *Storage) HideAllData() {
s.LocalStorage = nil
s.S3Storage = nil
s.GoogleDriveStorage = nil
s.NASStorage = nil
s.AzureBlobStorage = nil
s.FTPStorage = nil
s.SFTPStorage = nil
s.RcloneStorage = nil
}
func (s *Storage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
return s.getSpecificStorage().EncryptSensitiveData(encryptor)
}
@@ -93,6 +105,7 @@ func (s *Storage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) erro
func (s *Storage) Update(incoming *Storage) {
s.Name = incoming.Name
s.Type = incoming.Type
s.IsSystem = incoming.IsSystem
switch s.Type {
case StorageTypeLocal:

View File

@@ -229,12 +229,12 @@ acl = private`, s3Container.accessKey, s3Container.secretKey, s3Container.endpoi
context.Background(),
encryptor,
logger.GetLogger(),
fileID,
fileID.String(),
bytes.NewReader(fileData),
)
require.NoError(t, err, "SaveFile should succeed")
file, err := tc.storage.GetFile(encryptor, fileID)
file, err := tc.storage.GetFile(encryptor, fileID.String())
assert.NoError(t, err, "GetFile should succeed")
defer file.Close()
@@ -252,15 +252,15 @@ acl = private`, s3Container.accessKey, s3Container.secretKey, s3Container.endpoi
context.Background(),
encryptor,
logger.GetLogger(),
fileID,
fileID.String(),
bytes.NewReader(fileData),
)
require.NoError(t, err, "SaveFile should succeed")
err = tc.storage.DeleteFile(encryptor, fileID)
err = tc.storage.DeleteFile(encryptor, fileID.String())
assert.NoError(t, err, "DeleteFile should succeed")
file, err := tc.storage.GetFile(encryptor, fileID)
file, err := tc.storage.GetFile(encryptor, fileID.String())
assert.Error(t, err, "GetFile should fail for non-existent file")
if file != nil {
file.Close()
@@ -270,7 +270,7 @@ acl = private`, s3Container.accessKey, s3Container.secretKey, s3Container.endpoi
t.Run("Test_TestDeleteNonExistentFile_DoesNotError", func(t *testing.T) {
// Try to delete a non-existent file
nonExistentID := uuid.New()
err := tc.storage.DeleteFile(encryptor, nonExistentID)
err := tc.storage.DeleteFile(encryptor, nonExistentID.String())
assert.NoError(t, err, "DeleteFile should not error for non-existent file")
})
})

View File

@@ -68,7 +68,7 @@ func (s *AzureBlobStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error {
select {
@@ -82,7 +82,7 @@ func (s *AzureBlobStorage) SaveFile(
return err
}
blobName := s.buildBlobName(fileID.String())
blobName := s.buildBlobName(fileName)
blockBlobClient := client.ServiceClient().
NewContainerClient(s.ContainerName).
NewBlockBlobClient(blobName)
@@ -157,14 +157,14 @@ func (s *AzureBlobStorage) SaveFile(
func (s *AzureBlobStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) (io.ReadCloser, error) {
client, err := s.getClient(encryptor)
if err != nil {
return nil, err
}
blobName := s.buildBlobName(fileID.String())
blobName := s.buildBlobName(fileName)
response, err := client.DownloadStream(
context.TODO(),
@@ -179,13 +179,13 @@ func (s *AzureBlobStorage) GetFile(
return response.Body, nil
}
func (s *AzureBlobStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
func (s *AzureBlobStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
blobName := s.buildBlobName(fileID.String())
blobName := s.buildBlobName(fileName)
ctx, cancel := context.WithTimeout(context.Background(), azureDeleteTimeout)
defer cancel()

View File

@@ -41,7 +41,7 @@ func (f *FTPStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error {
select {
@@ -50,19 +50,19 @@ func (f *FTPStorage) SaveFile(
default:
}
logger.Info("Starting to save file to FTP storage", "fileId", fileID.String(), "host", f.Host)
logger.Info("Starting to save file to FTP storage", "fileName", fileName, "host", f.Host)
conn, err := f.connect(encryptor, ftpConnectTimeout)
if err != nil {
logger.Error("Failed to connect to FTP", "fileId", fileID.String(), "error", err)
logger.Error("Failed to connect to FTP", "fileName", fileName, "error", err)
return fmt.Errorf("failed to connect to FTP: %w", err)
}
defer func() {
if quitErr := conn.Quit(); quitErr != nil {
logger.Error(
"Failed to close FTP connection",
"fileId",
fileID.String(),
"fileName",
fileName,
"error",
quitErr,
)
@@ -73,8 +73,8 @@ func (f *FTPStorage) SaveFile(
if err := f.ensureDirectory(conn, f.Path); err != nil {
logger.Error(
"Failed to ensure directory",
"fileId",
fileID.String(),
"fileName",
fileName,
"path",
f.Path,
"error",
@@ -84,8 +84,8 @@ func (f *FTPStorage) SaveFile(
}
}
filePath := f.getFilePath(fileID.String())
logger.Debug("Uploading file to FTP", "fileId", fileID.String(), "filePath", filePath)
filePath := f.getFilePath(fileName)
logger.Debug("Uploading file to FTP", "fileName", fileName, "filePath", filePath)
ctxReader := &contextReader{ctx: ctx, reader: file}
@@ -93,18 +93,18 @@ func (f *FTPStorage) SaveFile(
if err != nil {
select {
case <-ctx.Done():
logger.Info("FTP upload cancelled", "fileId", fileID.String())
logger.Info("FTP upload cancelled", "fileName", fileName)
return ctx.Err()
default:
logger.Error("Failed to upload file to FTP", "fileId", fileID.String(), "error", err)
logger.Error("Failed to upload file to FTP", "fileName", fileName, "error", err)
return fmt.Errorf("failed to upload file to FTP: %w", err)
}
}
logger.Info(
"Successfully saved file to FTP storage",
"fileId",
fileID.String(),
"fileName",
fileName,
"filePath",
filePath,
)
@@ -113,14 +113,14 @@ func (f *FTPStorage) SaveFile(
func (f *FTPStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) (io.ReadCloser, error) {
conn, err := f.connect(encryptor, ftpConnectTimeout)
if err != nil {
return nil, fmt.Errorf("failed to connect to FTP: %w", err)
}
filePath := f.getFilePath(fileID.String())
filePath := f.getFilePath(fileName)
resp, err := conn.Retr(filePath)
if err != nil {
@@ -134,7 +134,7 @@ func (f *FTPStorage) GetFile(
}, nil
}
func (f *FTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
func (f *FTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
ctx, cancel := context.WithTimeout(context.Background(), ftpDeleteTimeout)
defer cancel()
@@ -146,7 +146,7 @@ func (f *FTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid
_ = conn.Quit()
}()
filePath := f.getFilePath(fileID.String())
filePath := f.getFilePath(fileName)
_, err = conn.FileSize(filePath)
if err != nil {
@@ -285,30 +285,30 @@ func (f *FTPStorage) ensureDirectory(conn *ftp.ServerConn, path string) error {
}
parts := strings.Split(path, "/")
currentPath := ""
currentDir, err := conn.CurrentDir()
if err != nil {
return fmt.Errorf("failed to get current directory: %w", err)
}
defer func() {
_ = conn.ChangeDir(currentDir)
}()
for _, part := range parts {
if part == "" || part == "." {
continue
}
if currentPath == "" {
currentPath = part
} else {
currentPath = currentPath + "/" + part
}
err := conn.ChangeDir(currentPath)
err := conn.ChangeDir(part)
if err != nil {
err = conn.MakeDir(currentPath)
err = conn.MakeDir(part)
if err != nil {
return fmt.Errorf("failed to create directory '%s': %w", currentPath, err)
return fmt.Errorf("failed to create directory '%s': %w", part, err)
}
err = conn.ChangeDir(part)
if err != nil {
return fmt.Errorf("failed to change into directory '%s': %w", part, err)
}
}
err = conn.ChangeDirToParent()
if err != nil {
return fmt.Errorf("failed to change to parent directory: %w", err)
}
}
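
The reworked ensureDirectory walks the target path one segment at a time relative to the connection's current directory, instead of issuing ChangeDir/MakeDir with an accumulated path, and restores the starting directory when done. A self-contained sketch of that relative traversal, assuming only the *ftp.ServerConn methods already used above (CurrentDir, ChangeDir, MakeDir); it is not line-for-line identical to the code in the diff:

// ensureDirectorySketch creates nested directories segment by segment,
// descending into each one, then returns to the directory it started from.
func ensureDirectorySketch(conn *ftp.ServerConn, path string) error {
    startDir, err := conn.CurrentDir()
    if err != nil {
        return fmt.Errorf("failed to get current directory: %w", err)
    }
    defer func() { _ = conn.ChangeDir(startDir) }()
    for _, part := range strings.Split(path, "/") {
        if part == "" || part == "." {
            continue
        }
        if err := conn.ChangeDir(part); err != nil {
            // Segment does not exist yet: create it, then step into it.
            if err := conn.MakeDir(part); err != nil {
                return fmt.Errorf("failed to create directory '%s': %w", part, err)
            }
            if err := conn.ChangeDir(part); err != nil {
                return fmt.Errorf("failed to change into directory '%s': %w", part, err)
            }
        }
    }
    return nil
}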

View File

@@ -50,21 +50,19 @@ func (s *GoogleDriveStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error {
return s.withRetryOnAuth(ctx, encryptor, func(driveService *drive.Service) error {
filename := fileID.String()
folderID, err := s.ensureBackupsFolderExists(ctx, driveService)
if err != nil {
return fmt.Errorf("failed to create/find backups folder: %w", err)
}
_ = s.deleteByName(ctx, driveService, filename, folderID)
_ = s.deleteByName(ctx, driveService, fileName, folderID)
fileMeta := &drive.File{
Name: filename,
Name: fileName,
Parents: []string{folderID},
}
@@ -91,7 +89,7 @@ func (s *GoogleDriveStorage) SaveFile(
logger.Info(
"file uploaded to Google Drive",
"name",
filename,
fileName,
"folder",
"databasus_backups",
)
@@ -152,7 +150,7 @@ func (r *backpressureReader) Read(p []byte) (n int, err error) {
func (s *GoogleDriveStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) (io.ReadCloser, error) {
var result io.ReadCloser
err := s.withRetryOnAuth(
@@ -164,7 +162,7 @@ func (s *GoogleDriveStorage) GetFile(
return fmt.Errorf("failed to find backups folder: %w", err)
}
fileIDGoogle, err := s.lookupFileID(driveService, fileID.String(), folderID)
fileIDGoogle, err := s.lookupFileID(driveService, fileName, folderID)
if err != nil {
return err
}
@@ -184,7 +182,7 @@ func (s *GoogleDriveStorage) GetFile(
func (s *GoogleDriveStorage) DeleteFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) error {
ctx, cancel := context.WithTimeout(context.Background(), gdDeleteTimeout)
defer cancel()
@@ -195,7 +193,7 @@ func (s *GoogleDriveStorage) DeleteFile(
return fmt.Errorf("failed to find backups folder: %w", err)
}
return s.deleteByName(ctx, driveService, fileID.String(), folderID)
return s.deleteByName(ctx, driveService, fileName, folderID)
})
}

View File

@@ -36,7 +36,7 @@ func (l *LocalStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error {
select {
@@ -45,7 +45,7 @@ func (l *LocalStorage) SaveFile(
default:
}
logger.Info("Starting to save file to local storage", "fileId", fileID.String())
logger.Info("Starting to save file to local storage", "fileName", fileName)
err := files_utils.EnsureDirectories([]string{
config.GetEnv().TempFolder,
@@ -54,15 +54,15 @@ func (l *LocalStorage) SaveFile(
return fmt.Errorf("failed to ensure directories: %w", err)
}
tempFilePath := filepath.Join(config.GetEnv().TempFolder, fileID.String())
logger.Debug("Creating temp file", "fileId", fileID.String(), "tempPath", tempFilePath)
tempFilePath := filepath.Join(config.GetEnv().TempFolder, fileName)
logger.Debug("Creating temp file", "fileName", fileName, "tempPath", tempFilePath)
tempFile, err := os.Create(tempFilePath)
if err != nil {
logger.Error(
"Failed to create temp file",
"fileId",
fileID.String(),
"fileName",
fileName,
"tempPath",
tempFilePath,
"error",
@@ -74,29 +74,29 @@ func (l *LocalStorage) SaveFile(
_ = tempFile.Close()
}()
logger.Debug("Copying file data to temp file", "fileId", fileID.String())
logger.Debug("Copying file data to temp file", "fileName", fileName)
_, err = copyWithContext(ctx, tempFile, file)
if err != nil {
logger.Error("Failed to write to temp file", "fileId", fileID.String(), "error", err)
logger.Error("Failed to write to temp file", "fileName", fileName, "error", err)
return fmt.Errorf("failed to write to temp file: %w", err)
}
if err = tempFile.Sync(); err != nil {
logger.Error("Failed to sync temp file", "fileId", fileID.String(), "error", err)
logger.Error("Failed to sync temp file", "fileName", fileName, "error", err)
return fmt.Errorf("failed to sync temp file: %w", err)
}
// Close the temp file explicitly before moving it (required on Windows)
if err = tempFile.Close(); err != nil {
logger.Error("Failed to close temp file", "fileId", fileID.String(), "error", err)
logger.Error("Failed to close temp file", "fileName", fileName, "error", err)
return fmt.Errorf("failed to close temp file: %w", err)
}
finalPath := filepath.Join(config.GetEnv().DataFolder, fileID.String())
finalPath := filepath.Join(config.GetEnv().DataFolder, fileName)
logger.Debug(
"Moving file from temp to final location",
"fileId",
fileID.String(),
"fileName",
fileName,
"finalPath",
finalPath,
)
@@ -105,8 +105,8 @@ func (l *LocalStorage) SaveFile(
if err = os.Rename(tempFilePath, finalPath); err != nil {
logger.Error(
"Failed to move file from temp to backups",
"fileId",
fileID.String(),
"fileName",
fileName,
"tempPath",
tempFilePath,
"finalPath",
@@ -119,8 +119,8 @@ func (l *LocalStorage) SaveFile(
logger.Info(
"Successfully saved file to local storage",
"fileId",
fileID.String(),
"fileName",
fileName,
"finalPath",
finalPath,
)
@@ -130,12 +130,12 @@ func (l *LocalStorage) SaveFile(
func (l *LocalStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) (io.ReadCloser, error) {
filePath := filepath.Join(config.GetEnv().DataFolder, fileID.String())
filePath := filepath.Join(config.GetEnv().DataFolder, fileName)
if _, err := os.Stat(filePath); os.IsNotExist(err) {
return nil, fmt.Errorf("file not found: %s", fileID.String())
return nil, fmt.Errorf("file not found: %s", fileName)
}
file, err := os.Open(filePath)
@@ -146,8 +146,8 @@ func (l *LocalStorage) GetFile(
return file, nil
}
func (l *LocalStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
filePath := filepath.Join(config.GetEnv().DataFolder, fileID.String())
func (l *LocalStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
filePath := filepath.Join(config.GetEnv().DataFolder, fileName)
if _, err := os.Stat(filePath); os.IsNotExist(err) {
return nil

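The local storage path keeps the same write pattern, now keyed by file name: write to a temp file, sync, close explicitly (required on Windows), then rename into the data folder. A condensed, self-contained sketch of that flow with the Databasus-specific logging and context-aware copy left out (standard io, os and path/filepath only):

// saveAtomically writes src to a temp file and moves it into place in one rename.
func saveAtomically(tempDir, dataDir, fileName string, src io.Reader) error {
    tmpPath := filepath.Join(tempDir, fileName)
    f, err := os.Create(tmpPath)
    if err != nil {
        return err
    }
    if _, err := io.Copy(f, src); err != nil {
        f.Close()
        return err
    }
    if err := f.Sync(); err != nil {
        f.Close()
        return err
    }
    if err := f.Close(); err != nil {
        return err
    }
    return os.Rename(tmpPath, filepath.Join(dataDir, fileName))
}
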
View File

@@ -46,7 +46,7 @@ func (n *NASStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error {
select {
@@ -55,19 +55,19 @@ func (n *NASStorage) SaveFile(
default:
}
logger.Info("Starting to save file to NAS storage", "fileId", fileID.String(), "host", n.Host)
logger.Info("Starting to save file to NAS storage", "fileName", fileName, "host", n.Host)
session, err := n.createSessionWithContext(ctx, encryptor)
if err != nil {
logger.Error("Failed to create NAS session", "fileId", fileID.String(), "error", err)
logger.Error("Failed to create NAS session", "fileName", fileName, "error", err)
return fmt.Errorf("failed to create NAS session: %w", err)
}
defer func() {
if logoffErr := session.Logoff(); logoffErr != nil {
logger.Error(
"Failed to logoff NAS session",
"fileId",
fileID.String(),
"fileName",
fileName,
"error",
logoffErr,
)
@@ -78,8 +78,8 @@ func (n *NASStorage) SaveFile(
if err != nil {
logger.Error(
"Failed to mount NAS share",
"fileId",
fileID.String(),
"fileName",
fileName,
"share",
n.Share,
"error",
@@ -91,8 +91,8 @@ func (n *NASStorage) SaveFile(
if umountErr := fs.Umount(); umountErr != nil {
logger.Error(
"Failed to unmount NAS share",
"fileId",
fileID.String(),
"fileName",
fileName,
"error",
umountErr,
)
@@ -104,8 +104,8 @@ func (n *NASStorage) SaveFile(
if err := n.ensureDirectory(fs, n.Path); err != nil {
logger.Error(
"Failed to ensure directory",
"fileId",
fileID.String(),
"fileName",
fileName,
"path",
n.Path,
"error",
@@ -115,15 +115,15 @@ func (n *NASStorage) SaveFile(
}
}
filePath := n.getFilePath(fileID.String())
logger.Debug("Creating file on NAS", "fileId", fileID.String(), "filePath", filePath)
filePath := n.getFilePath(fileName)
logger.Debug("Creating file on NAS", "fileName", fileName, "filePath", filePath)
nasFile, err := fs.Create(filePath)
if err != nil {
logger.Error(
"Failed to create file on NAS",
"fileId",
fileID.String(),
"fileName",
fileName,
"filePath",
filePath,
"error",
@@ -133,21 +133,21 @@ func (n *NASStorage) SaveFile(
}
defer func() {
if closeErr := nasFile.Close(); closeErr != nil {
logger.Error("Failed to close NAS file", "fileId", fileID.String(), "error", closeErr)
logger.Error("Failed to close NAS file", "fileName", fileName, "error", closeErr)
}
}()
logger.Debug("Copying file data to NAS", "fileId", fileID.String())
logger.Debug("Copying file data to NAS", "fileName", fileName)
_, err = copyWithContext(ctx, nasFile, file)
if err != nil {
logger.Error("Failed to write file to NAS", "fileId", fileID.String(), "error", err)
logger.Error("Failed to write file to NAS", "fileName", fileName, "error", err)
return fmt.Errorf("failed to write file to NAS: %w", err)
}
logger.Info(
"Successfully saved file to NAS storage",
"fileId",
fileID.String(),
"fileName",
fileName,
"filePath",
filePath,
)
@@ -156,7 +156,7 @@ func (n *NASStorage) SaveFile(
func (n *NASStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) (io.ReadCloser, error) {
session, err := n.createSession(encryptor)
if err != nil {
@@ -169,14 +169,14 @@ func (n *NASStorage) GetFile(
return nil, fmt.Errorf("failed to mount share '%s': %w", n.Share, err)
}
filePath := n.getFilePath(fileID.String())
filePath := n.getFilePath(fileName)
// Check if file exists
_, err = fs.Stat(filePath)
if err != nil {
_ = fs.Umount()
_ = session.Logoff()
return nil, fmt.Errorf("file not found: %s", fileID.String())
return nil, fmt.Errorf("file not found: %s", fileName)
}
nasFile, err := fs.Open(filePath)
@@ -194,7 +194,7 @@ func (n *NASStorage) GetFile(
}, nil
}
func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
ctx, cancel := context.WithTimeout(context.Background(), nasDeleteTimeout)
defer cancel()
@@ -214,7 +214,7 @@ func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid
_ = fs.Umount()
}()
filePath := n.getFilePath(fileID.String())
filePath := n.getFilePath(fileName)
_, err = fs.Stat(filePath)
if err != nil {

View File

@@ -41,7 +41,7 @@ func (r *RcloneStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error {
select {
@@ -50,28 +50,28 @@ func (r *RcloneStorage) SaveFile(
default:
}
logger.Info("Starting to save file to rclone storage", "fileId", fileID.String())
logger.Info("Starting to save file to rclone storage", "fileName", fileName)
remoteFs, err := r.getFs(ctx, encryptor)
if err != nil {
logger.Error("Failed to create rclone filesystem", "fileId", fileID.String(), "error", err)
logger.Error("Failed to create rclone filesystem", "fileName", fileName, "error", err)
return fmt.Errorf("failed to create rclone filesystem: %w", err)
}
filePath := r.getFilePath(fileID.String())
logger.Debug("Uploading file via rclone", "fileId", fileID.String(), "filePath", filePath)
filePath := r.getFilePath(fileName)
logger.Debug("Uploading file via rclone", "fileName", fileName, "filePath", filePath)
_, err = operations.Rcat(ctx, remoteFs, filePath, io.NopCloser(file), time.Now().UTC(), nil)
if err != nil {
select {
case <-ctx.Done():
logger.Info("Rclone upload cancelled", "fileId", fileID.String())
logger.Info("Rclone upload cancelled", "fileName", fileName)
return ctx.Err()
default:
logger.Error(
"Failed to upload file via rclone",
"fileId",
fileID.String(),
"fileName",
fileName,
"error",
err,
)
@@ -81,8 +81,8 @@ func (r *RcloneStorage) SaveFile(
logger.Info(
"Successfully saved file to rclone storage",
"fileId",
fileID.String(),
"fileName",
fileName,
"filePath",
filePath,
)
@@ -91,7 +91,7 @@ func (r *RcloneStorage) SaveFile(
func (r *RcloneStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) (io.ReadCloser, error) {
ctx := context.Background()
@@ -100,7 +100,7 @@ func (r *RcloneStorage) GetFile(
return nil, fmt.Errorf("failed to create rclone filesystem: %w", err)
}
filePath := r.getFilePath(fileID.String())
filePath := r.getFilePath(fileName)
obj, err := remoteFs.NewObject(ctx, filePath)
if err != nil {
@@ -115,7 +115,7 @@ func (r *RcloneStorage) GetFile(
return reader, nil
}
func (r *RcloneStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
func (r *RcloneStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
ctx, cancel := context.WithTimeout(context.Background(), rcloneDeleteTimeout)
defer cancel()
@@ -124,7 +124,7 @@ func (r *RcloneStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID u
return fmt.Errorf("failed to create rclone filesystem: %w", err)
}
filePath := r.getFilePath(fileID.String())
filePath := r.getFilePath(fileName)
obj, err := remoteFs.NewObject(ctx, filePath)
if err != nil {

View File

@@ -55,7 +55,7 @@ func (s *S3Storage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error {
select {
@@ -69,7 +69,7 @@ func (s *S3Storage) SaveFile(
return err
}
objectKey := s.buildObjectKey(fileID.String())
objectKey := s.buildObjectKey(fileName)
uploadID, err := coreClient.NewMultipartUpload(
ctx,
@@ -184,14 +184,14 @@ func (s *S3Storage) SaveFile(
func (s *S3Storage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) (io.ReadCloser, error) {
client, err := s.getClient(encryptor)
if err != nil {
return nil, err
}
objectKey := s.buildObjectKey(fileID.String())
objectKey := s.buildObjectKey(fileName)
object, err := client.GetObject(
context.TODO(),
@@ -221,13 +221,13 @@ func (s *S3Storage) GetFile(
return object, nil
}
func (s *S3Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
func (s *S3Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
objectKey := s.buildObjectKey(fileID.String())
objectKey := s.buildObjectKey(fileName)
ctx, cancel := context.WithTimeout(context.Background(), s3DeleteTimeout)
defer cancel()

View File

@@ -41,7 +41,7 @@ func (s *SFTPStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error {
select {
@@ -50,19 +50,19 @@ func (s *SFTPStorage) SaveFile(
default:
}
logger.Info("Starting to save file to SFTP storage", "fileId", fileID.String(), "host", s.Host)
logger.Info("Starting to save file to SFTP storage", "fileName", fileName, "host", s.Host)
client, sshConn, err := s.connect(encryptor, sftpConnectTimeout)
if err != nil {
logger.Error("Failed to connect to SFTP", "fileId", fileID.String(), "error", err)
logger.Error("Failed to connect to SFTP", "fileName", fileName, "error", err)
return fmt.Errorf("failed to connect to SFTP: %w", err)
}
defer func() {
if closeErr := client.Close(); closeErr != nil {
logger.Error(
"Failed to close SFTP client",
"fileId",
fileID.String(),
"fileName",
fileName,
"error",
closeErr,
)
@@ -70,8 +70,8 @@ func (s *SFTPStorage) SaveFile(
if closeErr := sshConn.Close(); closeErr != nil {
logger.Error(
"Failed to close SSH connection",
"fileId",
fileID.String(),
"fileName",
fileName,
"error",
closeErr,
)
@@ -82,8 +82,8 @@ func (s *SFTPStorage) SaveFile(
if err := s.ensureDirectory(client, s.Path); err != nil {
logger.Error(
"Failed to ensure directory",
"fileId",
fileID.String(),
"fileName",
fileName,
"path",
s.Path,
"error",
@@ -93,12 +93,12 @@ func (s *SFTPStorage) SaveFile(
}
}
filePath := s.getFilePath(fileID.String())
logger.Debug("Uploading file to SFTP", "fileId", fileID.String(), "filePath", filePath)
filePath := s.getFilePath(fileName)
logger.Debug("Uploading file to SFTP", "fileName", fileName, "filePath", filePath)
remoteFile, err := client.Create(filePath)
if err != nil {
logger.Error("Failed to create remote file", "fileId", fileID.String(), "error", err)
logger.Error("Failed to create remote file", "fileName", fileName, "error", err)
return fmt.Errorf("failed to create remote file: %w", err)
}
defer func() {
@@ -111,18 +111,18 @@ func (s *SFTPStorage) SaveFile(
if err != nil {
select {
case <-ctx.Done():
logger.Info("SFTP upload cancelled", "fileId", fileID.String())
logger.Info("SFTP upload cancelled", "fileName", fileName)
return ctx.Err()
default:
logger.Error("Failed to upload file to SFTP", "fileId", fileID.String(), "error", err)
logger.Error("Failed to upload file to SFTP", "fileName", fileName, "error", err)
return fmt.Errorf("failed to upload file to SFTP: %w", err)
}
}
logger.Info(
"Successfully saved file to SFTP storage",
"fileId",
fileID.String(),
"fileName",
fileName,
"filePath",
filePath,
)
@@ -131,14 +131,14 @@ func (s *SFTPStorage) SaveFile(
func (s *SFTPStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) (io.ReadCloser, error) {
client, sshConn, err := s.connect(encryptor, sftpConnectTimeout)
if err != nil {
return nil, fmt.Errorf("failed to connect to SFTP: %w", err)
}
filePath := s.getFilePath(fileID.String())
filePath := s.getFilePath(fileName)
remoteFile, err := client.Open(filePath)
if err != nil {
@@ -154,7 +154,7 @@ func (s *SFTPStorage) GetFile(
}, nil
}
func (s *SFTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
func (s *SFTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
ctx, cancel := context.WithTimeout(context.Background(), sftpDeleteTimeout)
defer cancel()
@@ -167,7 +167,7 @@ func (s *SFTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uui
_ = sshConn.Close()
}()
filePath := s.getFilePath(fileID.String())
filePath := s.getFilePath(fileName)
_, err = client.Stat(filePath)
if err != nil {

View File

@@ -165,7 +165,7 @@ func (r *StorageRepository) FindByWorkspaceID(workspaceID uuid.UUID) ([]*Storage
Preload("FTPStorage").
Preload("SFTPStorage").
Preload("RcloneStorage").
Where("workspace_id = ?", workspaceID).
Where("workspace_id = ? OR is_system = TRUE", workspaceID).
Order("name ASC").
Find(&storages).Error; err != nil {
return nil, err

View File

@@ -3,7 +3,9 @@ package storages
import (
"fmt"
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
users_enums "databasus-backend/internal/features/users/enums"
users_models "databasus-backend/internal/features/users/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
@@ -23,6 +25,32 @@ func (s *StorageService) SetStorageDatabaseCounter(storageDatabaseCounter Storag
s.storageDatabaseCounter = storageDatabaseCounter
}
func (s *StorageService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
storages, err := s.storageRepository.FindByWorkspaceID(workspaceID)
if err != nil {
return fmt.Errorf("failed to get storages for workspace deletion: %w", err)
}
for _, storage := range storages {
if storage.IsSystem && storage.WorkspaceID != workspaceID {
// skip system storage from another workspace
continue
}
if storage.IsSystem && storage.WorkspaceID == workspaceID {
return fmt.Errorf(
"system storage cannot be deleted due to workspace deletion, please transfer or remove storage first",
)
}
if err := s.storageRepository.Delete(storage); err != nil {
return fmt.Errorf("failed to delete storage %s: %w", storage.ID, err)
}
}
return nil
}
func (s *StorageService) SaveStorage(
user *users_models.User,
workspaceID uuid.UUID,
@@ -36,8 +64,18 @@ func (s *StorageService) SaveStorage(
return ErrInsufficientPermissionsToManageStorage
}
if config.GetEnv().IsCloud && storage.Type == StorageTypeLocal &&
user.Role != users_enums.UserRoleAdmin {
return ErrLocalStorageNotAllowedInCloudMode
}
isUpdate := storage.ID != uuid.Nil
if storage.IsSystem && user.Role != users_enums.UserRoleAdmin {
// only admin can manage system storage
return ErrInsufficientPermissionsToManageStorage
}
if isUpdate {
existingStorage, err := s.storageRepository.FindByID(storage.ID)
if err != nil {
@@ -48,8 +86,14 @@ func (s *StorageService) SaveStorage(
return ErrStorageDoesNotBelongToWorkspace
}
if existingStorage.IsSystem && !storage.IsSystem {
return ErrSystemStorageCannotBeMadePrivate
}
existingStorage.Update(storage)
oldName := existingStorage.Name
if err := existingStorage.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
@@ -63,11 +107,19 @@ func (s *StorageService) SaveStorage(
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Storage updated: %s", existingStorage.Name),
&user.ID,
&workspaceID,
)
if oldName != existingStorage.Name {
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Storage renamed from '%s' to '%s'", oldName, existingStorage.Name),
&user.ID,
&workspaceID,
)
} else {
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Storage updated: %s", existingStorage.Name),
&user.ID,
&workspaceID,
)
}
} else {
storage.WorkspaceID = workspaceID
@@ -111,6 +163,11 @@ func (s *StorageService) DeleteStorage(
return ErrInsufficientPermissionsToManageStorage
}
if storage.IsSystem && user.Role != users_enums.UserRoleAdmin {
// only admin can manage system storage
return ErrInsufficientPermissionsToManageStorage
}
attachedDatabasesIDs, err := s.storageDatabaseCounter.GetStorageAttachedDatabasesIDs(storage.ID)
if err != nil {
return err
@@ -142,16 +199,22 @@ func (s *StorageService) GetStorage(
return nil, err
}
canView, _, err := s.workspaceService.CanUserAccessWorkspace(storage.WorkspaceID, user)
if err != nil {
return nil, err
}
if !canView {
return nil, ErrInsufficientPermissionsToViewStorage
if !storage.IsSystem {
canView, _, err := s.workspaceService.CanUserAccessWorkspace(storage.WorkspaceID, user)
if err != nil {
return nil, err
}
if !canView {
return nil, ErrInsufficientPermissionsToViewStorage
}
}
storage.HideSensitiveData()
if storage.IsSystem && user.Role != users_enums.UserRoleAdmin {
storage.HideAllData()
}
return storage, nil
}
@@ -174,6 +237,10 @@ func (s *StorageService) GetStorages(
for _, storage := range storages {
storage.HideSensitiveData()
if storage.IsSystem && user.Role != users_enums.UserRoleAdmin {
storage.HideAllData()
}
}
return storages, nil
@@ -213,8 +280,14 @@ func (s *StorageService) TestStorageConnection(
}
func (s *StorageService) TestStorageConnectionDirect(
user *users_models.User,
storage *Storage,
) error {
if config.GetEnv().IsCloud && storage.Type == StorageTypeLocal &&
user.Role != users_enums.UserRoleAdmin {
return ErrLocalStorageNotAllowedInCloudMode
}
var usingStorage *Storage
if storage.ID != uuid.Nil {
@@ -258,6 +331,10 @@ func (s *StorageService) TransferStorageToWorkspace(
return err
}
if existingStorage.IsSystem {
return ErrSystemStorageCannotBeTransferred
}
canManageSource, err := s.workspaceService.CanUserManageDBs(existingStorage.WorkspaceID, user)
if err != nil {
return err
@@ -301,27 +378,29 @@ func (s *StorageService) TransferStorageToWorkspace(
return err
}
sourceWorkspace, err := s.workspaceService.GetWorkspaceByID(sourceWorkspaceID)
if err != nil {
return fmt.Errorf("failed to get source workspace: %w", err)
}
targetWorkspace, err := s.workspaceService.GetWorkspaceByID(targetWorkspaceID)
if err != nil {
return fmt.Errorf("failed to get target workspace: %w", err)
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Storage transferred: %s from workspace %s to workspace %s",
existingStorage.Name, sourceWorkspaceID, targetWorkspaceID),
fmt.Sprintf("Storage transferred out: %s to workspace '%s'",
existingStorage.Name, targetWorkspace.Name),
&user.ID,
&sourceWorkspaceID,
)
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Storage transferred in: %s from workspace '%s'",
existingStorage.Name, sourceWorkspace.Name),
&user.ID,
&targetWorkspaceID,
)
return nil
}
func (s *StorageService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
storages, err := s.storageRepository.FindByWorkspaceID(workspaceID)
if err != nil {
return fmt.Errorf("failed to get storages for workspace deletion: %w", err)
}
for _, storage := range storages {
if err := s.storageRepository.Delete(storage); err != nil {
return fmt.Errorf("failed to delete storage %s: %w", storage.ID, err)
}
}
return nil
}
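
The service-level rules introduced here come down to two admin-only constraints: local storage in cloud mode, and anything marked IsSystem. A hypothetical helper, not part of the diff, that condenses the recurring checks in SaveStorage, DeleteStorage and TestStorageConnectionDirect:

// checkStorageManagePermission mirrors the admin-only guards added above.
func checkStorageManagePermission(user *users_models.User, storage *Storage, isCloud bool) error {
    if isCloud && storage.Type == StorageTypeLocal && user.Role != users_enums.UserRoleAdmin {
        return ErrLocalStorageNotAllowedInCloudMode
    }
    if storage.IsSystem && user.Role != users_enums.UserRoleAdmin {
        return ErrInsufficientPermissionsToManageStorage
    }
    return nil
}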

View File

@@ -23,6 +23,18 @@ func (c *HealthcheckController) RegisterRoutes(router *gin.RouterGroup) {
// @Failure 503 {object} HealthcheckResponse
// @Router /system/health [get]
func (c *HealthcheckController) CheckHealth(ctx *gin.Context) {
// Allow unrestricted CORS for health check endpoint
// This enables monitoring tools from any origin to check system health
ctx.Header("Access-Control-Allow-Origin", "*")
ctx.Header("Access-Control-Allow-Methods", "GET, OPTIONS")
ctx.Header("Access-Control-Allow-Headers", "Content-Type")
// Handle preflight OPTIONS request
if ctx.Request.Method == "OPTIONS" {
ctx.AbortWithStatus(http.StatusNoContent)
return
}
err := c.healthcheckService.IsHealthy()
if err == nil {

View File

@@ -1,11 +1,14 @@
package system_healthcheck
import (
"context"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups/backuping"
"databasus-backend/internal/features/disk"
"databasus-backend/internal/storage"
cache_utils "databasus-backend/internal/util/cache"
"errors"
"time"
)
type HealthcheckService struct {
@@ -15,6 +18,20 @@ type HealthcheckService struct {
}
func (s *HealthcheckService) IsHealthy() error {
return s.performHealthCheck()
}
func (s *HealthcheckService) performHealthCheck() error {
// Check if cache is available with PING
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
client := cache_utils.GetValkeyClient()
pingResult := client.Do(ctx, client.B().Ping().Build())
if pingResult.Error() != nil {
return errors.New("cannot connect to valkey")
}
diskUsage, err := s.diskService.GetDiskUsage()
if err != nil {
return errors.New("cannot get disk usage")
@@ -35,11 +52,16 @@ func (s *HealthcheckService) IsHealthy() error {
if !s.backupBackgroundService.IsSchedulerRunning() {
return errors.New("backups are not running for more than 5 minutes")
}
if !s.backupBackgroundService.IsBackupNodesAvailable() {
return errors.New("no backup nodes available")
}
}
if config.GetEnv().IsProcessingNode {
if !s.backuperNode.IsBackuperRunning() {
return errors.New("backuper node is not running for more than 5 minutes")
}
}
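
The health check now begins with a cache liveness probe. Read as a standalone helper it looks like the sketch below; the client.B().Ping().Build() builder style suggests a valkey-go style client, which is an assumption here, while cache_utils.GetValkeyClient comes from the imports in the diff:

// pingValkey fails when the cache does not answer a PING within two seconds.
func pingValkey() error {
    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
    defer cancel()
    client := cache_utils.GetValkeyClient()
    if err := client.Do(ctx, client.B().Ping().Build()).Error(); err != nil {
        return errors.New("cannot connect to valkey")
    }
    return nil
}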

View File

@@ -147,6 +147,26 @@ func Test_BackupAndRestoreMariadb_WithReadOnlyUser_RestoreIsSuccessful(t *testin
}
}
func Test_BackupAndRestoreMariadb_WithExcludeEvents_EventsNotRestored(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MariadbVersion
port string
}{
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
testMariadbBackupRestoreWithExcludeEventsForVersion(t, tc.version, tc.port)
})
}
}
func testMariadbBackupRestoreForVersion(
t *testing.T,
mariadbVersion tools.MariadbVersion,
@@ -702,3 +722,145 @@ func updateMariadbDatabaseCredentialsViaAPI(
return &updatedDatabase
}
func testMariadbBackupRestoreWithExcludeEventsForVersion(
t *testing.T,
mariadbVersion tools.MariadbVersion,
port string,
) {
container, err := connectToMariadbContainer(mariadbVersion, port)
if err != nil {
t.Skipf("Skipping MariaDB %s test: %v", mariadbVersion, err)
return
}
defer func() {
if container.DB != nil {
container.DB.Close()
}
}()
setupMariadbTestData(t, container.DB)
_, err = container.DB.Exec(`
CREATE EVENT IF NOT EXISTS test_event
ON SCHEDULE EVERY 1 DAY
DO BEGIN
INSERT INTO test_data (name, value) VALUES ('event_test', 999);
END
`)
if err != nil {
t.Skipf(
"Skipping test: MariaDB version doesn't support events or event scheduler disabled: %v",
err,
)
return
}
router := createTestRouter()
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace(
"MariaDB Exclude Events Test Workspace",
user,
router,
)
storage := storages.CreateTestStorage(workspace.ID)
database := createMariadbDatabaseViaAPI(
t, router, "MariaDB Exclude Events Test Database", workspace.ID,
container.Host, container.Port,
container.Username, container.Password, container.Database,
container.Version,
user.Token,
)
database.Mariadb.IsExcludeEvents = true
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/update",
"Bearer "+user.Token,
database,
)
if w.Code != http.StatusOK {
t.Fatalf(
"Failed to update database with IsExcludeEvents. Status: %d, Body: %s",
w.Code,
w.Body.String(),
)
}
enableBackupsViaAPI(
t, router, database.ID, storage.ID,
backups_config.BackupEncryptionNone, user.Token,
)
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mariadb_no_events"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
assert.NoError(t, err)
newDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
container.Username, container.Password, container.Host, container.Port, newDBName)
newDB, err := sqlx.Connect("mysql", newDSN)
assert.NoError(t, err)
defer newDB.Close()
createMariadbRestoreViaAPI(
t, router, backup.ID,
container.Host, container.Port,
container.Username, container.Password, newDBName,
container.Version,
user.Token,
)
restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists int
err = newDB.Get(
&tableExists,
"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = 'test_data'",
newDBName,
)
assert.NoError(t, err)
assert.Equal(t, 1, tableExists, "Table 'test_data' should exist in restored database")
verifyMariadbDataIntegrity(t, container.DB, newDB)
var eventCount int
err = newDB.Get(
&eventCount,
"SELECT COUNT(*) FROM information_schema.events WHERE event_schema = ? AND event_name = 'test_event'",
newDBName,
)
assert.NoError(t, err)
assert.Equal(
t,
0,
eventCount,
"Event 'test_event' should NOT exist in restored database when IsExcludeEvents is true",
)
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
if err != nil {
t.Logf("Warning: Failed to delete backup file: %v", err)
}
test_utils.MakeDeleteRequest(
t,
router,
"/api/v1/databases/"+database.ID.String(),
"Bearer "+user.Token,
http.StatusNoContent,
)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

View File

@@ -385,13 +385,14 @@ func createMongodbDatabaseViaAPI(
Type: databases.DatabaseTypeMongodb,
Mongodb: &mongodbtypes.MongodbDatabase{
Host: host,
Port: port,
Port: &port,
Username: username,
Password: password,
Database: database,
AuthDatabase: authDatabase,
Version: version,
IsHttps: false,
IsSrv: false,
CpuCount: 1,
},
}
@@ -432,13 +433,14 @@ func createMongodbRestoreViaAPI(
request := restores_core.RestoreBackupRequest{
MongodbDatabase: &mongodbtypes.MongodbDatabase{
Host: host,
Port: port,
Port: &port,
Username: username,
Password: password,
Database: database,
AuthDatabase: authDatabase,
Version: version,
IsHttps: false,
IsSrv: false,
CpuCount: 1,
},
}

View File

@@ -726,41 +726,28 @@ func Test_InviteUserToWorkspace_MembershipReceivedAfterSignUp(t *testing.T) {
assert.Equal(t, workspaces_dto.AddStatusInvited, inviteResponse.Status)
// 3. Sign up the invited user
// 3. Sign up the invited user (now returns token directly)
signUpRequest := users_dto.SignUpRequestDTO{
Email: inviteEmail,
Password: "testpassword123",
Name: "Invited User",
}
resp := test_utils.MakePostRequest(
var signInResponse users_dto.SignInResponseDTO
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/users/signup",
"",
signUpRequest,
http.StatusOK,
)
assert.Contains(t, string(resp.Body), "User created successfully")
// 4. Sign in the newly registered user
signInRequest := users_dto.SignInRequestDTO{
Email: inviteEmail,
Password: "testpassword123",
}
var signInResponse users_dto.SignInResponseDTO
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/users/signin",
"",
signInRequest,
http.StatusOK,
&signInResponse,
)
// 5. Verify user is automatically added as member to workspace
assert.NotEmpty(t, signInResponse.Token)
assert.Equal(t, inviteEmail, signInResponse.Email)
// 4. Verify user is automatically added as member to workspace
var membersResponse workspaces_dto.GetMembersResponseDTO
test_utils.MakeGetRequestAndUnmarshal(
t,

View File

@@ -0,0 +1,593 @@
package users_controllers
import (
"net/http"
"testing"
"time"
users_dto "databasus-backend/internal/features/users/dto"
users_enums "databasus-backend/internal/features/users/enums"
users_models "databasus-backend/internal/features/users/models"
users_services "databasus-backend/internal/features/users/services"
users_testing "databasus-backend/internal/features/users/testing"
"databasus-backend/internal/storage"
test_utils "databasus-backend/internal/util/testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
)
func Test_SendResetPasswordCode_WithValidEmail_CodeSent(t *testing.T) {
router := createUserTestRouter()
mockEmailSender := users_testing.NewMockEmailSender()
users_services.GetUserService().SetEmailSender(mockEmailSender)
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
request := users_dto.SendResetPasswordCodeRequestDTO{
Email: user.Email,
}
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/send-reset-password-code",
"",
request,
http.StatusOK,
)
assert.Equal(t, 1, len(mockEmailSender.SentEmails))
assert.Equal(t, user.Email, mockEmailSender.SentEmails[0].To)
assert.Contains(t, mockEmailSender.SentEmails[0].Subject, "Password Reset")
}
func Test_SendResetPasswordCode_WithNonExistentUser_ReturnsSuccess(t *testing.T) {
router := createUserTestRouter()
mockEmailSender := users_testing.NewMockEmailSender()
users_services.GetUserService().SetEmailSender(mockEmailSender)
request := users_dto.SendResetPasswordCodeRequestDTO{
Email: "nonexistent" + uuid.New().String() + "@example.com",
}
// Should return success to prevent enumeration attacks
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/send-reset-password-code",
"",
request,
http.StatusOK,
)
// But no email should be sent
assert.Equal(t, 0, len(mockEmailSender.SentEmails))
}
func Test_SendResetPasswordCode_WithInvitedUser_ReturnsBadRequest(t *testing.T) {
router := createUserTestRouter()
mockEmailSender := users_testing.NewMockEmailSender()
users_services.GetUserService().SetEmailSender(mockEmailSender)
adminUser := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
email := "invited" + uuid.New().String() + "@example.com"
inviteRequest := users_dto.InviteUserRequestDTO{
Email: email,
}
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/invite",
"Bearer "+adminUser.Token,
inviteRequest,
http.StatusOK,
)
request := users_dto.SendResetPasswordCodeRequestDTO{
Email: email,
}
resp := test_utils.MakePostRequest(
t,
router,
"/api/v1/users/send-reset-password-code",
"",
request,
http.StatusBadRequest,
)
assert.Contains(t, string(resp.Body), "only active users")
}
func Test_SendResetPasswordCode_WithRateLimitExceeded_ReturnsTooManyRequests(t *testing.T) {
router := createUserTestRouter()
mockEmailSender := users_testing.NewMockEmailSender()
users_services.GetUserService().SetEmailSender(mockEmailSender)
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
request := users_dto.SendResetPasswordCodeRequestDTO{
Email: user.Email,
}
// Make 3 requests (should succeed)
for range 3 {
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/send-reset-password-code",
"",
request,
http.StatusOK,
)
}
// 4th request should be rate limited
resp := test_utils.MakePostRequest(
t,
router,
"/api/v1/users/send-reset-password-code",
"",
request,
http.StatusTooManyRequests,
)
assert.Contains(t, string(resp.Body), "Rate limit exceeded")
}
func Test_SendResetPasswordCode_WithInvalidJSON_ReturnsBadRequest(t *testing.T) {
router := createUserTestRouter()
resp := test_utils.MakeRequest(t, router, test_utils.RequestOptions{
Method: "POST",
URL: "/api/v1/users/send-reset-password-code",
Body: "invalid json",
ExpectedStatus: http.StatusBadRequest,
})
assert.Contains(t, string(resp.Body), "Invalid request format")
}
func Test_ResetPassword_WithValidCode_PasswordReset(t *testing.T) {
router := createUserTestRouter()
mockEmailSender := users_testing.NewMockEmailSender()
users_services.GetUserService().SetEmailSender(mockEmailSender)
email := "resettest" + uuid.New().String() + "@example.com"
oldPassword := "oldpassword123"
newPassword := "newpassword456"
// Create user
signupRequest := users_dto.SignUpRequestDTO{
Email: email,
Password: oldPassword,
Name: "Test User",
}
test_utils.MakePostRequest(t, router, "/api/v1/users/signup", "", signupRequest, http.StatusOK)
// Request reset code
sendCodeRequest := users_dto.SendResetPasswordCodeRequestDTO{
Email: email,
}
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/send-reset-password-code",
"",
sendCodeRequest,
http.StatusOK,
)
// Extract code from email
assert.Equal(t, 1, len(mockEmailSender.SentEmails))
emailBody := mockEmailSender.SentEmails[0].Body
code := extractCodeFromEmail(emailBody)
t.Logf("Extracted code: %s from email body (length: %d)", code, len(code))
assert.NotEmpty(t, code, "Code should be extracted from email")
assert.Len(t, code, 6, "Code should be 6 digits")
// Reset password
resetRequest := users_dto.ResetPasswordRequestDTO{
Email: email,
Code: code,
NewPassword: newPassword,
}
resp := test_utils.MakePostRequest(
t,
router,
"/api/v1/users/reset-password",
"",
resetRequest,
http.StatusOK,
)
if resp.StatusCode != http.StatusOK {
t.Logf("Reset password failed with body: %s", string(resp.Body))
}
// Verify old password doesn't work
oldSigninRequest := users_dto.SignInRequestDTO{
Email: email,
Password: oldPassword,
}
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/signin",
"",
oldSigninRequest,
http.StatusBadRequest,
)
// Verify new password works
newSigninRequest := users_dto.SignInRequestDTO{
Email: email,
Password: newPassword,
}
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/signin",
"",
newSigninRequest,
http.StatusOK,
)
}
func Test_ResetPassword_WithExpiredCode_ReturnsBadRequest(t *testing.T) {
router := createUserTestRouter()
mockEmailSender := users_testing.NewMockEmailSender()
users_services.GetUserService().SetEmailSender(mockEmailSender)
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
// Create expired reset code directly in database
code := "123456"
hashedCode, _ := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
expiredCode := &users_models.PasswordResetCode{
ID: uuid.New(),
UserID: user.UserID,
HashedCode: string(hashedCode),
ExpiresAt: time.Now().UTC().Add(-1 * time.Hour), // Expired 1 hour ago
IsUsed: false,
CreatedAt: time.Now().UTC().Add(-2 * time.Hour),
}
storage.GetDb().Create(expiredCode)
resetRequest := users_dto.ResetPasswordRequestDTO{
Email: user.Email,
Code: code,
NewPassword: "newpassword123",
}
resp := test_utils.MakePostRequest(
t,
router,
"/api/v1/users/reset-password",
"",
resetRequest,
http.StatusBadRequest,
)
assert.Contains(t, string(resp.Body), "invalid or expired")
}
func Test_ResetPassword_WithUsedCode_ReturnsBadRequest(t *testing.T) {
router := createUserTestRouter()
mockEmailSender := users_testing.NewMockEmailSender()
users_services.GetUserService().SetEmailSender(mockEmailSender)
email := "usedcode" + uuid.New().String() + "@example.com"
// Create user
signupRequest := users_dto.SignUpRequestDTO{
Email: email,
Password: "password123",
Name: "Test User",
}
test_utils.MakePostRequest(t, router, "/api/v1/users/signup", "", signupRequest, http.StatusOK)
// Request reset code
sendCodeRequest := users_dto.SendResetPasswordCodeRequestDTO{
Email: email,
}
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/send-reset-password-code",
"",
sendCodeRequest,
http.StatusOK,
)
code := extractCodeFromEmail(mockEmailSender.SentEmails[0].Body)
// Use code first time
resetRequest := users_dto.ResetPasswordRequestDTO{
Email: email,
Code: code,
NewPassword: "newpassword123",
}
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/reset-password",
"",
resetRequest,
http.StatusOK,
)
// Try to use same code again
resetRequest2 := users_dto.ResetPasswordRequestDTO{
Email: email,
Code: code,
NewPassword: "anotherpassword456",
}
resp := test_utils.MakePostRequest(
t,
router,
"/api/v1/users/reset-password",
"",
resetRequest2,
http.StatusBadRequest,
)
assert.Contains(t, string(resp.Body), "invalid or expired")
}
func Test_ResetPassword_WithWrongCode_ReturnsBadRequest(t *testing.T) {
router := createUserTestRouter()
mockEmailSender := users_testing.NewMockEmailSender()
users_services.GetUserService().SetEmailSender(mockEmailSender)
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
// Request reset code
sendCodeRequest := users_dto.SendResetPasswordCodeRequestDTO{
Email: user.Email,
}
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/send-reset-password-code",
"",
sendCodeRequest,
http.StatusOK,
)
// Try to reset with wrong code
resetRequest := users_dto.ResetPasswordRequestDTO{
Email: user.Email,
Code: "999999", // Wrong code
NewPassword: "newpassword123",
}
resp := test_utils.MakePostRequest(
t,
router,
"/api/v1/users/reset-password",
"",
resetRequest,
http.StatusBadRequest,
)
assert.Contains(t, string(resp.Body), "invalid")
}
func Test_ResetPassword_WithInvalidNewPassword_ReturnsBadRequest(t *testing.T) {
router := createUserTestRouter()
mockEmailSender := users_testing.NewMockEmailSender()
users_services.GetUserService().SetEmailSender(mockEmailSender)
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
resetRequest := users_dto.ResetPasswordRequestDTO{
Email: user.Email,
Code: "123456",
NewPassword: "short", // Too short
}
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/reset-password",
"",
resetRequest,
http.StatusBadRequest,
)
}
func Test_ResetPassword_EmailSendFailure_ReturnsError(t *testing.T) {
router := createUserTestRouter()
mockEmailSender := users_testing.NewMockEmailSender()
mockEmailSender.ShouldFail = true
users_services.GetUserService().SetEmailSender(mockEmailSender)
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
request := users_dto.SendResetPasswordCodeRequestDTO{
Email: user.Email,
}
resp := test_utils.MakePostRequest(
t,
router,
"/api/v1/users/send-reset-password-code",
"",
request,
http.StatusBadRequest,
)
assert.Contains(t, string(resp.Body), "failed to send email")
}
func Test_ResetPasswordFlow_E2E_CompletesSuccessfully(t *testing.T) {
router := createUserTestRouter()
mockEmailSender := users_testing.NewMockEmailSender()
users_services.GetUserService().SetEmailSender(mockEmailSender)
email := "e2e" + uuid.New().String() + "@example.com"
initialPassword := "initialpass123"
newPassword := "brandnewpass456"
// 1. Create user via signup
signupRequest := users_dto.SignUpRequestDTO{
Email: email,
Password: initialPassword,
Name: "E2E Test User",
}
test_utils.MakePostRequest(t, router, "/api/v1/users/signup", "", signupRequest, http.StatusOK)
// 2. Verify can sign in with initial password
signinRequest := users_dto.SignInRequestDTO{
Email: email,
Password: initialPassword,
}
var signinResponse users_dto.SignInResponseDTO
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/users/signin",
"",
signinRequest,
http.StatusOK,
&signinResponse,
)
assert.NotEmpty(t, signinResponse.Token)
// 3. Request password reset code
sendCodeRequest := users_dto.SendResetPasswordCodeRequestDTO{
Email: email,
}
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/send-reset-password-code",
"",
sendCodeRequest,
http.StatusOK,
)
// 4. Verify email was sent
assert.Equal(t, 1, len(mockEmailSender.SentEmails))
code := extractCodeFromEmail(mockEmailSender.SentEmails[0].Body)
assert.NotEmpty(t, code)
// 5. Reset password using code
resetRequest := users_dto.ResetPasswordRequestDTO{
Email: email,
Code: code,
NewPassword: newPassword,
}
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/reset-password",
"",
resetRequest,
http.StatusOK,
)
// 6. Verify old password no longer works
oldSignin := users_dto.SignInRequestDTO{
Email: email,
Password: initialPassword,
}
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/signin",
"",
oldSignin,
http.StatusBadRequest,
)
// 7. Verify new password works
newSignin := users_dto.SignInRequestDTO{
Email: email,
Password: newPassword,
}
var finalResponse users_dto.SignInResponseDTO
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/users/signin",
"",
newSignin,
http.StatusOK,
&finalResponse,
)
assert.NotEmpty(t, finalResponse.Token)
}
func Test_ResetPassword_WithInvalidJSON_ReturnsBadRequest(t *testing.T) {
router := createUserTestRouter()
resp := test_utils.MakeRequest(t, router, test_utils.RequestOptions{
Method: "POST",
URL: "/api/v1/users/reset-password",
Body: "invalid json",
ExpectedStatus: http.StatusBadRequest,
})
assert.Contains(t, string(resp.Body), "Invalid request format")
}
// Helper function to extract 6-digit code from email HTML body
func extractCodeFromEmail(emailBody string) string {
// Look for pattern: <h1 ... >CODE</h1>
// First find <h1
h1Start := -1 // -1 sentinel so a <h1 at position 0 is not treated as "not found"
for i := 0; i < len(emailBody)-3; i++ {
if emailBody[i:i+3] == "<h1" {
h1Start = i
break
}
}
if h1Start == -1 {
return ""
}
// Find the > after <h1
contentStart := h1Start
for i := h1Start; i < len(emailBody); i++ {
if emailBody[i] == '>' {
contentStart = i + 1
break
}
}
// Find </h1>
contentEnd := contentStart
for i := contentStart; i < len(emailBody)-5; i++ {
if emailBody[i:i+5] == "</h1>" {
contentEnd = i
break
}
}
if contentEnd <= contentStart {
return ""
}
// Extract content and remove whitespace
content := emailBody[contentStart:contentEnd]
code := ""
for i := 0; i < len(content); i++ {
if isDigit(content[i]) {
code += string(content[i])
}
}
if len(code) == 6 {
return code
}
return ""
}
func isDigit(b byte) bool {
return b >= '0' && b <= '9'
}
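
The hand-rolled scan above could also be written with a regular expression; a roughly equivalent sketch (it requires the six digits to be contiguous inside the <h1> element, whereas the scan above collects digits anywhere between the tags):

var h1CodeRe = regexp.MustCompile(`<h1[^>]*>\s*(\d{6})\s*</h1>`)

// extractCodeFromEmailRe pulls a 6-digit reset code out of the first <h1> element.
func extractCodeFromEmailRe(emailBody string) string {
    if m := h1CodeRe.FindStringSubmatch(emailBody); m != nil {
        return m[1]
    }
    return ""
}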

View File

@@ -11,6 +11,7 @@ import (
user_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
cache_utils "databasus-backend/internal/util/cache"
cloudflare_turnstile "databasus-backend/internal/util/cloudflare_turnstile"
"github.com/gin-gonic/gin"
)
@@ -28,6 +29,10 @@ func (c *UserController) RegisterRoutes(router *gin.RouterGroup) {
router.GET("/users/admin/has-password", c.IsAdminHasPassword)
router.POST("/users/admin/set-password", c.SetAdminPassword)
// Password reset (no auth required)
router.POST("/users/send-reset-password-code", c.SendResetPasswordCode)
router.POST("/users/reset-password", c.ResetPassword)
// OAuth callbacks
router.POST("/auth/github/callback", c.HandleGitHubOAuth)
router.POST("/auth/google/callback", c.HandleGoogleOAuth)
@@ -47,7 +52,7 @@ func (c *UserController) RegisterProtectedRoutes(router *gin.RouterGroup) {
// @Accept json
// @Produce json
// @Param request body users_dto.SignUpRequestDTO true "User signup data"
// @Success 200
// @Success 200 {object} users_dto.SignInResponseDTO
// @Failure 400
// @Router /users/signup [post]
func (c *UserController) SignUp(ctx *gin.Context) {
@@ -57,13 +62,41 @@ func (c *UserController) SignUp(ctx *gin.Context) {
return
}
err := c.userService.SignUp(&request)
// Verify Cloudflare Turnstile if enabled
turnstileService := cloudflare_turnstile.GetCloudflareTurnstileService()
if turnstileService.IsEnabled() {
if request.CloudflareTurnstileToken == nil || *request.CloudflareTurnstileToken == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "Cloudflare Turnstile verification required"},
)
return
}
clientIP := ctx.ClientIP()
isValid, err := turnstileService.VerifyToken(*request.CloudflareTurnstileToken, clientIP)
if err != nil || !isValid {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "Cloudflare Turnstile verification failed"},
)
return
}
}
user, err := c.userService.SignUp(&request)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, gin.H{"message": "User created successfully"})
response, err := c.userService.GenerateAccessToken(user)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
return
}
ctx.JSON(http.StatusOK, response)
}
// SignIn
@@ -84,6 +117,28 @@ func (c *UserController) SignIn(ctx *gin.Context) {
return
}
// Verify Cloudflare Turnstile if enabled
turnstileService := cloudflare_turnstile.GetCloudflareTurnstileService()
if turnstileService.IsEnabled() {
if request.CloudflareTurnstileToken == nil || *request.CloudflareTurnstileToken == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "Cloudflare Turnstile verification required"},
)
return
}
clientIP := ctx.ClientIP()
isValid, err := turnstileService.VerifyToken(*request.CloudflareTurnstileToken, clientIP)
if err != nil || !isValid {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "Cloudflare Turnstile verification failed"},
)
return
}
}
allowed, _ := c.rateLimiter.CheckLimit(request.Email, "signin", 10, 1*time.Minute)
if !allowed {
ctx.JSON(
@@ -340,3 +395,92 @@ func (c *UserController) HandleGoogleOAuth(ctx *gin.Context) {
ctx.JSON(http.StatusOK, response)
}
// SendResetPasswordCode
// @Summary Send password reset code
// @Description Send a password reset code to the user's email
// @Tags users
// @Accept json
// @Produce json
// @Param request body users_dto.SendResetPasswordCodeRequestDTO true "Email address"
// @Success 200 {object} map[string]string
// @Failure 400 {object} map[string]string
// @Failure 429 {object} map[string]string
// @Router /users/send-reset-password-code [post]
func (c *UserController) SendResetPasswordCode(ctx *gin.Context) {
var request user_dto.SendResetPasswordCodeRequestDTO
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
return
}
// Verify Cloudflare Turnstile if enabled
turnstileService := cloudflare_turnstile.GetCloudflareTurnstileService()
if turnstileService.IsEnabled() {
if request.CloudflareTurnstileToken == nil || *request.CloudflareTurnstileToken == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "Cloudflare Turnstile verification required"},
)
return
}
clientIP := ctx.ClientIP()
isValid, err := turnstileService.VerifyToken(*request.CloudflareTurnstileToken, clientIP)
if err != nil || !isValid {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "Cloudflare Turnstile verification failed"},
)
return
}
}
allowed, _ := c.rateLimiter.CheckLimit(
request.Email,
"reset-password",
3,
1*time.Hour,
)
if !allowed {
ctx.JSON(
http.StatusTooManyRequests,
gin.H{"error": "Rate limit exceeded. Please try again later."},
)
return
}
err := c.userService.SendResetPasswordCode(request.Email)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, gin.H{"message": "If the email exists, a reset code has been sent"})
}
// ResetPassword
// @Summary Reset password with code
// @Description Reset user password using the code sent via email
// @Tags users
// @Accept json
// @Produce json
// @Param request body users_dto.ResetPasswordRequestDTO true "Reset password data"
// @Success 200 {object} map[string]string
// @Failure 400 {object} map[string]string
// @Router /users/reset-password [post]
func (c *UserController) ResetPassword(ctx *gin.Context) {
var request user_dto.ResetPasswordRequestDTO
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
return
}
err := c.userService.ResetPassword(request.Email, request.Code, request.NewPassword)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, gin.H{"message": "Password reset successfully"})
}
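
The Turnstile check is repeated verbatim in SignUp, SignIn and SendResetPasswordCode. A possible follow-up (hypothetical, not part of this changeset) would be to pull it into a shared helper; a minimal sketch, assuming the same error messages and status codes as above:

    // Hypothetical helper, not in this diff: writes the error response itself
    // and reports whether the handler may continue.
    func verifyTurnstileOrAbort(ctx *gin.Context, token *string) bool {
        svc := cloudflare_turnstile.GetCloudflareTurnstileService()
        if !svc.IsEnabled() {
            return true
        }
        if token == nil || *token == "" {
            ctx.JSON(http.StatusBadRequest, gin.H{"error": "Cloudflare Turnstile verification required"})
            return false
        }
        ok, err := svc.VerifyToken(*token, ctx.ClientIP())
        if err != nil || !ok {
            ctx.JSON(http.StatusBadRequest, gin.H{"error": "Cloudflare Turnstile verification failed"})
            return false
        }
        return true
    }

    // Usage in each handler:
    //   if !verifyTurnstileOrAbort(ctx, request.CloudflareTurnstileToken) {
    //       return
    //   }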

View File

@@ -27,7 +27,20 @@ func Test_SignUpUser_WithValidData_UserCreated(t *testing.T) {
Name: "Test User",
}
test_utils.MakePostRequest(t, router, "/api/v1/users/signup", "", request, http.StatusOK)
var response users_dto.SignInResponseDTO
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/users/signup",
"",
request,
http.StatusOK,
&response,
)
assert.NotEmpty(t, response.Token)
assert.NotEqual(t, uuid.Nil, response.UserID)
assert.Equal(t, request.Email, response.Email)
}
func Test_SignUpUser_WithInvalidJSON_ReturnsBadRequest(t *testing.T) {

View File

@@ -9,14 +9,16 @@ import (
)
type SignUpRequestDTO struct {
Email string `json:"email" binding:"required"`
Password string `json:"password" binding:"required,min=8"`
Name string `json:"name" binding:"required"`
Email string `json:"email" binding:"required"`
Password string `json:"password" binding:"required,min=8"`
Name string `json:"name" binding:"required"`
CloudflareTurnstileToken *string `json:"cloudflareTurnstileToken"`
}
type SignInRequestDTO struct {
Email string `json:"email" binding:"required"`
Password string `json:"password" binding:"required"`
Email string `json:"email" binding:"required"`
Password string `json:"password" binding:"required"`
CloudflareTurnstileToken *string `json:"cloudflareTurnstileToken"`
}
type SignInResponseDTO struct {
@@ -92,3 +94,14 @@ type OAuthCallbackResponseDTO struct {
Token string `json:"token"`
IsNewUser bool `json:"isNewUser"`
}
type SendResetPasswordCodeRequestDTO struct {
Email string `json:"email" binding:"required,email"`
CloudflareTurnstileToken *string `json:"cloudflareTurnstileToken"`
}
type ResetPasswordRequestDTO struct {
Email string `json:"email" binding:"required,email"`
Code string `json:"code" binding:"required"`
NewPassword string `json:"newPassword" binding:"required,min=8"`
}

View File

@@ -7,3 +7,7 @@ import (
type AuditLogWriter interface {
WriteAuditLog(message string, userID *uuid.UUID, workspaceID *uuid.UUID)
}
type EmailSender interface {
SendEmail(to, subject, body string) error
}

View File

@@ -0,0 +1,24 @@
package users_models
import (
"time"
"github.com/google/uuid"
)
type PasswordResetCode struct {
ID uuid.UUID `json:"id" gorm:"column:id"`
UserID uuid.UUID `json:"userId" gorm:"column:user_id"`
HashedCode string `json:"-" gorm:"column:hashed_code"`
ExpiresAt time.Time `json:"expiresAt" gorm:"column:expires_at"`
IsUsed bool `json:"isUsed" gorm:"column:is_used"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
}
func (PasswordResetCode) TableName() string {
return "password_reset_codes"
}
func (p *PasswordResetCode) IsValid() bool {
return !p.IsUsed && time.Now().UTC().Before(p.ExpiresAt)
}
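
The model stores only HashedCode; this diff does not show where codes are generated or which hash function is used. A purely illustrative sketch, assuming a random 6-digit code hashed with SHA-256 and a 15-minute lifetime (both assumptions, not confirmed by the changeset):

    // Hypothetical generator, not part of this diff. Requires crypto/rand,
    // crypto/sha256, encoding/hex, fmt, math/big in addition to the imports above.
    func newPasswordResetCode(userID uuid.UUID) (plainCode string, record *PasswordResetCode, err error) {
        n, err := rand.Int(rand.Reader, big.NewInt(1_000_000))
        if err != nil {
            return "", nil, err
        }
        plainCode = fmt.Sprintf("%06d", n.Int64())
        sum := sha256.Sum256([]byte(plainCode))
        now := time.Now().UTC()
        record = &PasswordResetCode{
            ID:         uuid.New(),
            UserID:     userID,
            HashedCode: hex.EncodeToString(sum[:]),
            ExpiresAt:  now.Add(15 * time.Minute), // expiry is an assumption
            CreatedAt:  now,
        }
        return plainCode, record, nil
    }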

View File

@@ -2,6 +2,7 @@ package users_repositories
var userRepository = &UserRepository{}
var usersSettingsRepository = &UsersSettingsRepository{}
var passwordResetRepository = &PasswordResetRepository{}
func GetUserRepository() *UserRepository {
return userRepository
@@ -10,3 +11,7 @@ func GetUserRepository() *UserRepository {
func GetUsersSettingsRepository() *UsersSettingsRepository {
return usersSettingsRepository
}
func GetPasswordResetRepository() *PasswordResetRepository {
return passwordResetRepository
}

View File

@@ -0,0 +1,61 @@
package users_repositories
import (
"time"
users_models "databasus-backend/internal/features/users/models"
"databasus-backend/internal/storage"
"github.com/google/uuid"
)
type PasswordResetRepository struct{}
func (r *PasswordResetRepository) CreateResetCode(code *users_models.PasswordResetCode) error {
if code.ID == uuid.Nil {
code.ID = uuid.New()
}
return storage.GetDb().Create(code).Error
}
func (r *PasswordResetRepository) GetValidCodeByUserID(
userID uuid.UUID,
) (*users_models.PasswordResetCode, error) {
var code users_models.PasswordResetCode
err := storage.GetDb().
Where("user_id = ? AND is_used = ? AND expires_at > ?", userID, false, time.Now().UTC()).
Order("created_at DESC").
First(&code).Error
if err != nil {
return nil, err
}
return &code, nil
}
func (r *PasswordResetRepository) MarkCodeAsUsed(codeID uuid.UUID) error {
return storage.GetDb().Model(&users_models.PasswordResetCode{}).
Where("id = ?", codeID).
Update("is_used", true).Error
}
func (r *PasswordResetRepository) DeleteExpiredCodes() error {
return storage.GetDb().
Where("expires_at < ?", time.Now().UTC()).
Delete(&users_models.PasswordResetCode{}).Error
}
func (r *PasswordResetRepository) CountRecentCodesByUserID(
userID uuid.UUID,
since time.Time,
) (int64, error) {
var count int64
err := storage.GetDb().Model(&users_models.PasswordResetCode{}).
Where("user_id = ? AND created_at > ?", userID, since).
Count(&count).Error
return count, err
}
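
DeleteExpiredCodes has no caller in the hunks shown here; presumably something invokes it periodically. A minimal sketch of such a cleanup loop (hypothetical, not part of this diff; the hourly interval is an assumption):

    // Hypothetical background cleanup inside package users_repositories.
    // Requires: import "log" (time is already imported above).
    func startResetCodeCleanup(repo *PasswordResetRepository, stop <-chan struct{}) {
        ticker := time.NewTicker(time.Hour)
        defer ticker.Stop()
        for {
            select {
            case <-ticker.C:
                if err := repo.DeleteExpiredCodes(); err != nil {
                    log.Printf("failed to delete expired password reset codes: %v", err)
                }
            case <-stop:
                return
            }
        }
    }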

View File

@@ -1,6 +1,7 @@
package users_services
import (
"databasus-backend/internal/features/email"
"databasus-backend/internal/features/encryption/secrets"
users_repositories "databasus-backend/internal/features/users/repositories"
)
@@ -10,6 +11,8 @@ var userService = &UserService{
secrets.GetSecretKeyService(),
settingsService,
nil,
email.GetEmailSMTPSender(),
users_repositories.GetPasswordResetRepository(),
}
var settingsService = &SettingsService{
users_repositories.GetUsersSettingsRepository(),

Some files were not shown because too many files have changed in this diff.